mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
184 Commits
fix/copilo
...
abhi/add-b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
54bf45656a | ||
|
|
2f32217c7c | ||
|
|
7b64fbc931 | ||
|
|
1a0234c946 | ||
|
|
048fb06b0a | ||
|
|
1e14634d3d | ||
|
|
3f653e6614 | ||
|
|
c9c3d54b2b | ||
|
|
53d58e21d3 | ||
|
|
fa04fb41d8 | ||
|
|
d9c16ded65 | ||
|
|
6dc8429ae7 | ||
|
|
cfe22e5a8f | ||
|
|
a8259ca935 | ||
|
|
1f1288d623 | ||
|
|
02645732b8 | ||
|
|
ba301a3912 | ||
|
|
0cd9c0d87a | ||
|
|
a083493aa2 | ||
|
|
c51dc7ad99 | ||
|
|
bc6b82218a | ||
|
|
83e49f71cd | ||
|
|
ef446e4fe9 | ||
|
|
7b1e8ed786 | ||
|
|
7ccfff1040 | ||
|
|
81c7685a82 | ||
|
|
3595c6e769 | ||
|
|
1c2953d61b | ||
|
|
755bc84b1a | ||
|
|
ade2baa58f | ||
|
|
4d35534a89 | ||
|
|
2cc748f34c | ||
|
|
c2e79fa5e1 | ||
|
|
89a5b3178a | ||
|
|
c62d9a24ff | ||
|
|
0e0bfaac29 | ||
|
|
0633475915 | ||
|
|
34a2f9a0a2 | ||
|
|
9f4caa7dfc | ||
|
|
0876d22e22 | ||
|
|
15e3980d65 | ||
|
|
fe9eb2564b | ||
|
|
5641cdd3ca | ||
|
|
bfb843a56e | ||
|
|
684845d946 | ||
|
|
6a6b23c2e1 | ||
|
|
d0a1d72e8a | ||
|
|
f1945d6a2f | ||
|
|
6491cb1e23 | ||
|
|
c7124a5240 | ||
|
|
5537cb2858 | ||
|
|
aef5f6d666 | ||
|
|
8063391d0a | ||
|
|
0bbb12d688 | ||
|
|
eadc68f2a5 | ||
|
|
19d775c435 | ||
|
|
eca7b5e793 | ||
|
|
c304a4937a | ||
|
|
7ead4c040f | ||
|
|
8cfabcf4fd | ||
|
|
7bf407b66c | ||
|
|
0f813f1bf9 | ||
|
|
aa08063939 | ||
|
|
bde6a4c0df | ||
|
|
d56452898a | ||
|
|
7507240177 | ||
|
|
d7c3f5b8fc | ||
|
|
3e108a813a | ||
|
|
08c49a78f8 | ||
|
|
5d56548e6b | ||
|
|
6ecf55d214 | ||
|
|
7c8c7bf395 | ||
|
|
0b9e0665dd | ||
|
|
be18436e8f | ||
|
|
f6f268a1f0 | ||
|
|
ea0333c1fc | ||
|
|
21c705af6e | ||
|
|
a576be9db2 | ||
|
|
5e90585f10 | ||
|
|
3e22a0e786 | ||
|
|
6abe39b33a | ||
|
|
476cf1c601 | ||
|
|
25022f2d1e | ||
|
|
ce1675cfc7 | ||
|
|
3d0ede9f34 | ||
|
|
5474f7c495 | ||
|
|
f1b771b7ee | ||
|
|
aa7a2f0a48 | ||
|
|
3722d05b9b | ||
|
|
592830ce9b | ||
|
|
6cc680f71c | ||
|
|
b342bfa3ba | ||
|
|
0215332386 | ||
|
|
160d6eddfb | ||
|
|
a5db9c05d0 | ||
|
|
b74d41d50c | ||
|
|
a897f9e124 | ||
|
|
7fd26d3554 | ||
|
|
b504cf9854 | ||
|
|
29da8db48e | ||
|
|
757ec1f064 | ||
|
|
9442c648a4 | ||
|
|
1c51dd18aa | ||
|
|
6f4f80871d | ||
|
|
e8cca6cd9a | ||
|
|
bf6308e87c | ||
|
|
4e59143d16 | ||
|
|
d5efb6915b | ||
|
|
b9aac42056 | ||
|
|
95651d33da | ||
|
|
b30418d833 | ||
|
|
ed729ddbe2 | ||
|
|
8c7030af0b | ||
|
|
195b14286a | ||
|
|
29ca034e40 | ||
|
|
1d9dd782a8 | ||
|
|
a1cb3d2a91 | ||
|
|
1b91327034 | ||
|
|
c7cdb40c5b | ||
|
|
77fb4419d0 | ||
|
|
9f002ce8f6 | ||
|
|
74691076c6 | ||
|
|
b15ad0df9b | ||
|
|
2136defea8 | ||
|
|
6e61cb103c | ||
|
|
0e72e1f5e7 | ||
|
|
163b0b3c9d | ||
|
|
ef42b17e3b | ||
|
|
a18ffd0b21 | ||
|
|
e40c8c70ce | ||
|
|
9cdcd6793f | ||
|
|
fc64f83331 | ||
|
|
7718c49f05 | ||
|
|
0a1591fce2 | ||
|
|
681bb7c2b4 | ||
|
|
0818cd6683 | ||
|
|
7a39bdfaf8 | ||
|
|
0b151f64e8 | ||
|
|
be2a48aedb | ||
|
|
aeca4dbb79 | ||
|
|
7b85eeaae2 | ||
|
|
4db3be2d61 | ||
|
|
f57a1995d0 | ||
|
|
3928c35928 | ||
|
|
dc77e7b6e6 | ||
|
|
ba75cc28b5 | ||
|
|
15bcdae4e8 | ||
|
|
e9ba7e51db | ||
|
|
d23248f065 | ||
|
|
905373a712 | ||
|
|
ee9d39bc0f | ||
|
|
05aaf7a85e | ||
|
|
9d4dcbd9e0 | ||
|
|
074be7aea6 | ||
|
|
39d28b24fc | ||
|
|
bf79a7748a | ||
|
|
649d4ab7f5 | ||
|
|
223df9d3da | ||
|
|
187ab04745 | ||
|
|
e2d3c8a217 | ||
|
|
647c8ed8d4 | ||
|
|
27d94e395c | ||
|
|
b8f5c208d0 | ||
|
|
ca216dfd7f | ||
|
|
f9f358c526 | ||
|
|
52b3aebf71 | ||
|
|
965b7d3e04 | ||
|
|
c2368f15ff | ||
|
|
9ac3f64d56 | ||
|
|
5035b69c79 | ||
|
|
86af8fc856 | ||
|
|
dfa517300b | ||
|
|
43b25b5e2f | ||
|
|
ab0b537cc7 | ||
|
|
9a8c6ad609 | ||
|
|
e8c50b96d1 | ||
|
|
30e854569a | ||
|
|
301d7cbada | ||
|
|
d95aef7665 | ||
|
|
cb166dd6fb | ||
|
|
3d31f62bf1 | ||
|
|
b8b6c9de23 | ||
|
|
4f6055f494 | ||
|
|
695a185fa1 |
79
.claude/skills/pr-address/SKILL.md
Normal file
79
.claude/skills/pr-address/SKILL.md
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
---
|
||||||
|
name: pr-address
|
||||||
|
description: Address PR review comments and loop until CI green and all comments resolved. TRIGGER when user asks to address comments, fix PR feedback, respond to reviewers, or babysit/monitor a PR.
|
||||||
|
user-invocable: true
|
||||||
|
args: "[PR number or URL] — if omitted, finds PR for current branch."
|
||||||
|
metadata:
|
||||||
|
author: autogpt-team
|
||||||
|
version: "1.0.0"
|
||||||
|
---
|
||||||
|
|
||||||
|
# PR Address
|
||||||
|
|
||||||
|
## Find the PR
|
||||||
|
|
||||||
|
```bash
|
||||||
|
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
|
||||||
|
gh pr view {N}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Fetch comments (all sources)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews # top-level reviews
|
||||||
|
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments # inline review comments
|
||||||
|
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments # PR conversation comments
|
||||||
|
```
|
||||||
|
|
||||||
|
**Bots to watch for:**
|
||||||
|
- `autogpt-reviewer` — posts "Blockers", "Should Fix", "Nice to Have". Address ALL of them.
|
||||||
|
- `sentry[bot]` — bug predictions. Fix real bugs, explain false positives.
|
||||||
|
- `coderabbitai[bot]` — automated review. Address actionable items.
|
||||||
|
|
||||||
|
## For each unaddressed comment
|
||||||
|
|
||||||
|
Address comments **one at a time**: fix → commit → push → inline reply → next.
|
||||||
|
|
||||||
|
1. Read the referenced code, make the fix (or reply explaining why it's not needed)
|
||||||
|
2. Commit and push the fix
|
||||||
|
3. Reply **inline** (not as a new top-level comment) referencing the fixing commit — this is what resolves the conversation for bot reviewers (coderabbitai, sentry):
|
||||||
|
|
||||||
|
| Comment type | How to reply |
|
||||||
|
|---|---|
|
||||||
|
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="Fixed in <commit-sha>: <description>"` |
|
||||||
|
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="Fixed in <commit-sha>: <description>"` |
|
||||||
|
|
||||||
|
## Format and commit
|
||||||
|
|
||||||
|
After fixing, format the changed code:
|
||||||
|
|
||||||
|
- **Backend** (from `autogpt_platform/backend/`): `poetry run format`
|
||||||
|
- **Frontend** (from `autogpt_platform/frontend/`): `pnpm format && pnpm lint && pnpm types`
|
||||||
|
|
||||||
|
If API routes changed, regenerate the frontend client:
|
||||||
|
```bash
|
||||||
|
cd autogpt_platform/backend && poetry run rest &
|
||||||
|
REST_PID=$!
|
||||||
|
trap "kill $REST_PID 2>/dev/null" EXIT
|
||||||
|
WAIT=0; until curl -sf http://localhost:8006/health > /dev/null 2>&1; do sleep 1; WAIT=$((WAIT+1)); [ $WAIT -ge 60 ] && echo "Timed out" && exit 1; done
|
||||||
|
cd ../frontend && pnpm generate:api:force
|
||||||
|
kill $REST_PID 2>/dev/null; trap - EXIT
|
||||||
|
```
|
||||||
|
Never manually edit files in `src/app/api/__generated__/`.
|
||||||
|
|
||||||
|
Then commit and **push immediately** — never batch commits without pushing.
|
||||||
|
|
||||||
|
For backend commits in worktrees: `poetry run git commit` (pre-commit hooks).
|
||||||
|
|
||||||
|
## The loop
|
||||||
|
|
||||||
|
```text
|
||||||
|
address comments → format → commit → push
|
||||||
|
→ re-check comments → fix new ones → push
|
||||||
|
→ wait for CI → re-check comments after CI settles
|
||||||
|
→ repeat until: all comments addressed AND CI green AND no new comments arriving
|
||||||
|
```
|
||||||
|
|
||||||
|
While CI runs, stay productive: run local tests, address remaining comments.
|
||||||
|
|
||||||
|
**The loop ends when:** CI fully green + all comments addressed + no new comments since CI settled.
|
||||||
74
.claude/skills/pr-review/SKILL.md
Normal file
74
.claude/skills/pr-review/SKILL.md
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
---
|
||||||
|
name: pr-review
|
||||||
|
description: Review a PR for correctness, security, code quality, and testing issues. TRIGGER when user asks to review a PR, check PR quality, or give feedback on a PR.
|
||||||
|
user-invocable: true
|
||||||
|
args: "[PR number or URL] — if omitted, finds PR for current branch."
|
||||||
|
metadata:
|
||||||
|
author: autogpt-team
|
||||||
|
version: "1.0.0"
|
||||||
|
---
|
||||||
|
|
||||||
|
# PR Review
|
||||||
|
|
||||||
|
## Find the PR
|
||||||
|
|
||||||
|
```bash
|
||||||
|
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
|
||||||
|
gh pr view {N}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Read the diff
|
||||||
|
|
||||||
|
```bash
|
||||||
|
gh pr diff {N}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Fetch existing review comments
|
||||||
|
|
||||||
|
Before posting anything, fetch existing inline comments to avoid duplicates:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments
|
||||||
|
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews
|
||||||
|
```
|
||||||
|
|
||||||
|
## What to check
|
||||||
|
|
||||||
|
**Correctness:** logic errors, off-by-one, missing edge cases, race conditions (TOCTOU in file access, credit charging), error handling gaps, async correctness (missing `await`, unclosed resources).
|
||||||
|
|
||||||
|
**Security:** input validation at boundaries, no injection (command, XSS, SQL), secrets not logged, file paths sanitized (`os.path.basename()` in error messages).
|
||||||
|
|
||||||
|
**Code quality:** apply rules from backend/frontend CLAUDE.md files.
|
||||||
|
|
||||||
|
**Architecture:** DRY, single responsibility, modular functions. `Security()` vs `Depends()` for FastAPI auth. `data:` for SSE events, `: comment` for heartbeats. `transaction=True` for Redis pipelines.
|
||||||
|
|
||||||
|
**Testing:** edge cases covered, colocated `*_test.py` (backend) / `__tests__/` (frontend), mocks target where symbol is **used** not defined, `AsyncMock` for async.
|
||||||
|
|
||||||
|
## Output format
|
||||||
|
|
||||||
|
Every comment **must** be prefixed with `🤖` and a criticality badge:
|
||||||
|
|
||||||
|
| Tier | Badge | Meaning |
|
||||||
|
|---|---|---|
|
||||||
|
| Blocker | `🔴 **Blocker**` | Must fix before merge |
|
||||||
|
| Should Fix | `🟠 **Should Fix**` | Important improvement |
|
||||||
|
| Nice to Have | `🟡 **Nice to Have**` | Minor suggestion |
|
||||||
|
| Nit | `🔵 **Nit**` | Style / wording |
|
||||||
|
|
||||||
|
Example: `🤖 🔴 **Blocker**: Missing error handling for X — suggest wrapping in try/except.`
|
||||||
|
|
||||||
|
## Post inline comments
|
||||||
|
|
||||||
|
For each finding, post an inline comment on the PR (do not just write a local report):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Get the latest commit SHA for the PR
|
||||||
|
COMMIT_SHA=$(gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.head.sha')
|
||||||
|
|
||||||
|
# Post an inline comment on a specific file/line
|
||||||
|
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments \
|
||||||
|
-f body="🤖 🔴 **Blocker**: <description>" \
|
||||||
|
-f commit_id="$COMMIT_SHA" \
|
||||||
|
-f path="<file path>" \
|
||||||
|
-F line=<line number>
|
||||||
|
```
|
||||||
85
.claude/skills/worktree/SKILL.md
Normal file
85
.claude/skills/worktree/SKILL.md
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
---
|
||||||
|
name: worktree
|
||||||
|
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, and generates Prisma client. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
|
||||||
|
user-invocable: true
|
||||||
|
args: "[name] — optional worktree name (e.g., 'AutoGPT7'). If omitted, uses next available AutoGPT<N>."
|
||||||
|
metadata:
|
||||||
|
author: autogpt-team
|
||||||
|
version: "3.0.0"
|
||||||
|
---
|
||||||
|
|
||||||
|
# Worktree Setup
|
||||||
|
|
||||||
|
## Create the worktree
|
||||||
|
|
||||||
|
Derive paths from the git toplevel. If a name is provided as argument, use it. Otherwise, check `git worktree list` and pick the next `AutoGPT<N>`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ROOT=$(git rev-parse --show-toplevel)
|
||||||
|
PARENT=$(dirname "$ROOT")
|
||||||
|
|
||||||
|
# From an existing branch
|
||||||
|
git worktree add "$PARENT/<NAME>" <branch-name>
|
||||||
|
|
||||||
|
# From a new branch off dev
|
||||||
|
git worktree add -b <new-branch> "$PARENT/<NAME>" dev
|
||||||
|
```
|
||||||
|
|
||||||
|
## Copy environment files
|
||||||
|
|
||||||
|
Copy `.env` from the root worktree. Falls back to `.env.default` if `.env` doesn't exist.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ROOT=$(git rev-parse --show-toplevel)
|
||||||
|
TARGET="$(dirname "$ROOT")/<NAME>"
|
||||||
|
|
||||||
|
for envpath in autogpt_platform/backend autogpt_platform/frontend autogpt_platform; do
|
||||||
|
if [ -f "$ROOT/$envpath/.env" ]; then
|
||||||
|
cp "$ROOT/$envpath/.env" "$TARGET/$envpath/.env"
|
||||||
|
elif [ -f "$ROOT/$envpath/.env.default" ]; then
|
||||||
|
cp "$ROOT/$envpath/.env.default" "$TARGET/$envpath/.env"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
## Install dependencies
|
||||||
|
|
||||||
|
```bash
|
||||||
|
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
|
||||||
|
cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install
|
||||||
|
cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate
|
||||||
|
cd "$TARGET/autogpt_platform/frontend" && pnpm install
|
||||||
|
```
|
||||||
|
|
||||||
|
Replace `<NAME>` with the actual worktree name (e.g., `AutoGPT7`).
|
||||||
|
|
||||||
|
## Running the app (optional)
|
||||||
|
|
||||||
|
Backend uses ports: 8001, 8002, 8003, 8005, 8006, 8007, 8008. Free them first if needed:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
|
||||||
|
for port in 8001 8002 8003 8005 8006 8007 8008; do
|
||||||
|
lsof -ti :$port | xargs kill -9 2>/dev/null || true
|
||||||
|
done
|
||||||
|
cd "$TARGET/autogpt_platform/backend" && poetry run app
|
||||||
|
```
|
||||||
|
|
||||||
|
## CoPilot testing
|
||||||
|
|
||||||
|
SDK mode spawns a Claude subprocess — won't work inside Claude Code. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode.
|
||||||
|
|
||||||
|
## Cleanup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Replace <NAME> with the actual worktree name (e.g., AutoGPT7)
|
||||||
|
git worktree remove "$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Alternative: Branchlet (optional)
|
||||||
|
|
||||||
|
If [branchlet](https://www.npmjs.com/package/branchlet) is installed:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
branchlet create -n <name> -s <source-branch> -b <new-branch>
|
||||||
|
```
|
||||||
@@ -5,42 +5,13 @@
|
|||||||
!docs/
|
!docs/
|
||||||
|
|
||||||
# Platform - Libs
|
# Platform - Libs
|
||||||
!autogpt_platform/autogpt_libs/autogpt_libs/
|
!autogpt_platform/autogpt_libs/
|
||||||
!autogpt_platform/autogpt_libs/pyproject.toml
|
|
||||||
!autogpt_platform/autogpt_libs/poetry.lock
|
|
||||||
!autogpt_platform/autogpt_libs/README.md
|
|
||||||
|
|
||||||
# Platform - Backend
|
# Platform - Backend
|
||||||
!autogpt_platform/backend/backend/
|
!autogpt_platform/backend/
|
||||||
!autogpt_platform/backend/test/e2e_test_data.py
|
|
||||||
!autogpt_platform/backend/migrations/
|
|
||||||
!autogpt_platform/backend/schema.prisma
|
|
||||||
!autogpt_platform/backend/pyproject.toml
|
|
||||||
!autogpt_platform/backend/poetry.lock
|
|
||||||
!autogpt_platform/backend/README.md
|
|
||||||
!autogpt_platform/backend/.env
|
|
||||||
!autogpt_platform/backend/gen_prisma_types_stub.py
|
|
||||||
|
|
||||||
# Platform - Market
|
|
||||||
!autogpt_platform/market/market/
|
|
||||||
!autogpt_platform/market/scripts.py
|
|
||||||
!autogpt_platform/market/schema.prisma
|
|
||||||
!autogpt_platform/market/pyproject.toml
|
|
||||||
!autogpt_platform/market/poetry.lock
|
|
||||||
!autogpt_platform/market/README.md
|
|
||||||
|
|
||||||
# Platform - Frontend
|
# Platform - Frontend
|
||||||
!autogpt_platform/frontend/src/
|
!autogpt_platform/frontend/
|
||||||
!autogpt_platform/frontend/public/
|
|
||||||
!autogpt_platform/frontend/scripts/
|
|
||||||
!autogpt_platform/frontend/package.json
|
|
||||||
!autogpt_platform/frontend/pnpm-lock.yaml
|
|
||||||
!autogpt_platform/frontend/tsconfig.json
|
|
||||||
!autogpt_platform/frontend/README.md
|
|
||||||
## config
|
|
||||||
!autogpt_platform/frontend/*.config.*
|
|
||||||
!autogpt_platform/frontend/.env.*
|
|
||||||
!autogpt_platform/frontend/.env
|
|
||||||
|
|
||||||
# Classic - AutoGPT
|
# Classic - AutoGPT
|
||||||
!classic/original_autogpt/autogpt/
|
!classic/original_autogpt/autogpt/
|
||||||
@@ -64,6 +35,38 @@
|
|||||||
# Classic - Frontend
|
# Classic - Frontend
|
||||||
!classic/frontend/build/web/
|
!classic/frontend/build/web/
|
||||||
|
|
||||||
# Explicitly re-ignore some folders
|
# Explicitly re-ignore unwanted files from whitelisted directories
|
||||||
.*
|
# Note: These patterns MUST come after the whitelist rules to take effect
|
||||||
**/__pycache__
|
|
||||||
|
# Hidden files and directories (but keep frontend .env files needed for build)
|
||||||
|
**/.*
|
||||||
|
!autogpt_platform/frontend/.env
|
||||||
|
!autogpt_platform/frontend/.env.default
|
||||||
|
!autogpt_platform/frontend/.env.production
|
||||||
|
|
||||||
|
# Python artifacts
|
||||||
|
**/__pycache__/
|
||||||
|
**/*.pyc
|
||||||
|
**/*.pyo
|
||||||
|
**/.venv/
|
||||||
|
**/.ruff_cache/
|
||||||
|
**/.pytest_cache/
|
||||||
|
**/.coverage
|
||||||
|
**/htmlcov/
|
||||||
|
|
||||||
|
# Node artifacts
|
||||||
|
**/node_modules/
|
||||||
|
**/.next/
|
||||||
|
**/storybook-static/
|
||||||
|
**/playwright-report/
|
||||||
|
**/test-results/
|
||||||
|
|
||||||
|
# Build artifacts
|
||||||
|
**/dist/
|
||||||
|
**/build/
|
||||||
|
!autogpt_platform/frontend/src/**/build/
|
||||||
|
**/target/
|
||||||
|
|
||||||
|
# Logs and temp files
|
||||||
|
**/*.log
|
||||||
|
**/*.tmp
|
||||||
|
|||||||
1229
.github/scripts/detect_overlaps.py
vendored
Normal file
1229
.github/scripts/detect_overlaps.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
42
.github/workflows/claude-ci-failure-auto-fix.yml
vendored
42
.github/workflows/claude-ci-failure-auto-fix.yml
vendored
@@ -40,6 +40,48 @@ jobs:
|
|||||||
git checkout -b "$BRANCH_NAME"
|
git checkout -b "$BRANCH_NAME"
|
||||||
echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT
|
echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
# Backend Python/Poetry setup (so Claude can run linting/tests)
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.11"
|
||||||
|
|
||||||
|
- name: Set up Python dependency cache
|
||||||
|
uses: actions/cache@v5
|
||||||
|
with:
|
||||||
|
path: ~/.cache/pypoetry
|
||||||
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
|
|
||||||
|
- name: Install Poetry
|
||||||
|
run: |
|
||||||
|
cd autogpt_platform/backend
|
||||||
|
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||||
|
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||||
|
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||||
|
|
||||||
|
- name: Install Python dependencies
|
||||||
|
working-directory: autogpt_platform/backend
|
||||||
|
run: poetry install
|
||||||
|
|
||||||
|
- name: Generate Prisma Client
|
||||||
|
working-directory: autogpt_platform/backend
|
||||||
|
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||||
|
|
||||||
|
# Frontend Node.js/pnpm setup (so Claude can run linting/tests)
|
||||||
|
- name: Enable corepack
|
||||||
|
run: corepack enable
|
||||||
|
|
||||||
|
- name: Set up Node.js
|
||||||
|
uses: actions/setup-node@v6
|
||||||
|
with:
|
||||||
|
node-version: "22"
|
||||||
|
cache: "pnpm"
|
||||||
|
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||||
|
|
||||||
|
- name: Install JavaScript dependencies
|
||||||
|
working-directory: autogpt_platform/frontend
|
||||||
|
run: pnpm install --frozen-lockfile
|
||||||
|
|
||||||
- name: Get CI failure details
|
- name: Get CI failure details
|
||||||
id: failure_details
|
id: failure_details
|
||||||
uses: actions/github-script@v8
|
uses: actions/github-script@v8
|
||||||
|
|||||||
22
.github/workflows/claude-dependabot.yml
vendored
22
.github/workflows/claude-dependabot.yml
vendored
@@ -77,27 +77,15 @@ jobs:
|
|||||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||||
|
- name: Enable corepack
|
||||||
|
run: corepack enable
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22"
|
node-version: "22"
|
||||||
|
cache: "pnpm"
|
||||||
- name: Enable corepack
|
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||||
run: corepack enable
|
|
||||||
|
|
||||||
- name: Set pnpm store directory
|
|
||||||
run: |
|
|
||||||
pnpm config set store-dir ~/.pnpm-store
|
|
||||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
|
||||||
|
|
||||||
- name: Cache frontend dependencies
|
|
||||||
uses: actions/cache@v5
|
|
||||||
with:
|
|
||||||
path: ~/.pnpm-store
|
|
||||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
|
||||||
${{ runner.os }}-pnpm-
|
|
||||||
|
|
||||||
- name: Install JavaScript dependencies
|
- name: Install JavaScript dependencies
|
||||||
working-directory: autogpt_platform/frontend
|
working-directory: autogpt_platform/frontend
|
||||||
|
|||||||
22
.github/workflows/claude.yml
vendored
22
.github/workflows/claude.yml
vendored
@@ -93,27 +93,15 @@ jobs:
|
|||||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||||
|
- name: Enable corepack
|
||||||
|
run: corepack enable
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v6
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22"
|
node-version: "22"
|
||||||
|
cache: "pnpm"
|
||||||
- name: Enable corepack
|
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||||
run: corepack enable
|
|
||||||
|
|
||||||
- name: Set pnpm store directory
|
|
||||||
run: |
|
|
||||||
pnpm config set store-dir ~/.pnpm-store
|
|
||||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
|
||||||
|
|
||||||
- name: Cache frontend dependencies
|
|
||||||
uses: actions/cache@v5
|
|
||||||
with:
|
|
||||||
path: ~/.pnpm-store
|
|
||||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
|
||||||
${{ runner.os }}-pnpm-
|
|
||||||
|
|
||||||
- name: Install JavaScript dependencies
|
- name: Install JavaScript dependencies
|
||||||
working-directory: autogpt_platform/frontend
|
working-directory: autogpt_platform/frontend
|
||||||
|
|||||||
4
.github/workflows/codeql.yml
vendored
4
.github/workflows/codeql.yml
vendored
@@ -62,7 +62,7 @@ jobs:
|
|||||||
|
|
||||||
# Initializes the CodeQL tools for scanning.
|
# Initializes the CodeQL tools for scanning.
|
||||||
- name: Initialize CodeQL
|
- name: Initialize CodeQL
|
||||||
uses: github/codeql-action/init@v3
|
uses: github/codeql-action/init@v4
|
||||||
with:
|
with:
|
||||||
languages: ${{ matrix.language }}
|
languages: ${{ matrix.language }}
|
||||||
build-mode: ${{ matrix.build-mode }}
|
build-mode: ${{ matrix.build-mode }}
|
||||||
@@ -93,6 +93,6 @@ jobs:
|
|||||||
exit 1
|
exit 1
|
||||||
|
|
||||||
- name: Perform CodeQL Analysis
|
- name: Perform CodeQL Analysis
|
||||||
uses: github/codeql-action/analyze@v3
|
uses: github/codeql-action/analyze@v4
|
||||||
with:
|
with:
|
||||||
category: "/language:${{matrix.language}}"
|
category: "/language:${{matrix.language}}"
|
||||||
|
|||||||
34
.github/workflows/docs-claude-review.yml
vendored
34
.github/workflows/docs-claude-review.yml
vendored
@@ -7,6 +7,10 @@ on:
|
|||||||
- "docs/integrations/**"
|
- "docs/integrations/**"
|
||||||
- "autogpt_platform/backend/backend/blocks/**"
|
- "autogpt_platform/backend/backend/blocks/**"
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: claude-docs-review-${{ github.event.pull_request.number }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
claude-review:
|
claude-review:
|
||||||
# Only run for PRs from members/collaborators
|
# Only run for PRs from members/collaborators
|
||||||
@@ -91,5 +95,35 @@ jobs:
|
|||||||
3. Read corresponding documentation files to verify accuracy
|
3. Read corresponding documentation files to verify accuracy
|
||||||
4. Provide your feedback as a PR comment
|
4. Provide your feedback as a PR comment
|
||||||
|
|
||||||
|
## IMPORTANT: Comment Marker
|
||||||
|
Start your PR comment with exactly this HTML comment marker on its own line:
|
||||||
|
<!-- CLAUDE_DOCS_REVIEW -->
|
||||||
|
|
||||||
|
This marker is used to identify and replace your comment on subsequent runs.
|
||||||
|
|
||||||
Be constructive and specific. If everything looks good, say so!
|
Be constructive and specific. If everything looks good, say so!
|
||||||
If there are issues, explain what's wrong and suggest how to fix it.
|
If there are issues, explain what's wrong and suggest how to fix it.
|
||||||
|
|
||||||
|
- name: Delete old Claude review comments
|
||||||
|
env:
|
||||||
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
run: |
|
||||||
|
# Get all comment IDs with our marker, sorted by creation date (oldest first)
|
||||||
|
COMMENT_IDS=$(gh api \
|
||||||
|
repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments \
|
||||||
|
--jq '[.[] | select(.body | contains("<!-- CLAUDE_DOCS_REVIEW -->"))] | sort_by(.created_at) | .[].id')
|
||||||
|
|
||||||
|
# Count comments
|
||||||
|
COMMENT_COUNT=$(echo "$COMMENT_IDS" | grep -c . || true)
|
||||||
|
|
||||||
|
if [ "$COMMENT_COUNT" -gt 1 ]; then
|
||||||
|
# Delete all but the last (newest) comment
|
||||||
|
echo "$COMMENT_IDS" | head -n -1 | while read -r COMMENT_ID; do
|
||||||
|
if [ -n "$COMMENT_ID" ]; then
|
||||||
|
echo "Deleting old review comment: $COMMENT_ID"
|
||||||
|
gh api -X DELETE repos/${{ github.repository }}/issues/comments/$COMMENT_ID
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
else
|
||||||
|
echo "No old review comments to clean up"
|
||||||
|
fi
|
||||||
|
|||||||
9
.github/workflows/platform-backend-ci.yml
vendored
9
.github/workflows/platform-backend-ci.yml
vendored
@@ -41,13 +41,18 @@ jobs:
|
|||||||
ports:
|
ports:
|
||||||
- 6379:6379
|
- 6379:6379
|
||||||
rabbitmq:
|
rabbitmq:
|
||||||
image: rabbitmq:3.12-management
|
image: rabbitmq:4.1.4
|
||||||
ports:
|
ports:
|
||||||
- 5672:5672
|
- 5672:5672
|
||||||
- 15672:15672
|
|
||||||
env:
|
env:
|
||||||
RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
|
RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
|
||||||
RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
|
RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
|
||||||
|
options: >-
|
||||||
|
--health-cmd "rabbitmq-diagnostics -q ping"
|
||||||
|
--health-interval 30s
|
||||||
|
--health-timeout 10s
|
||||||
|
--health-retries 5
|
||||||
|
--health-start-period 10s
|
||||||
clamav:
|
clamav:
|
||||||
image: clamav/clamav-debian:latest
|
image: clamav/clamav-debian:latest
|
||||||
ports:
|
ports:
|
||||||
|
|||||||
247
.github/workflows/platform-frontend-ci.yml
vendored
247
.github/workflows/platform-frontend-ci.yml
vendored
@@ -6,10 +6,16 @@ on:
|
|||||||
paths:
|
paths:
|
||||||
- ".github/workflows/platform-frontend-ci.yml"
|
- ".github/workflows/platform-frontend-ci.yml"
|
||||||
- "autogpt_platform/frontend/**"
|
- "autogpt_platform/frontend/**"
|
||||||
|
- "autogpt_platform/backend/Dockerfile"
|
||||||
|
- "autogpt_platform/docker-compose.yml"
|
||||||
|
- "autogpt_platform/docker-compose.platform.yml"
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
- ".github/workflows/platform-frontend-ci.yml"
|
- ".github/workflows/platform-frontend-ci.yml"
|
||||||
- "autogpt_platform/frontend/**"
|
- "autogpt_platform/frontend/**"
|
||||||
|
- "autogpt_platform/backend/Dockerfile"
|
||||||
|
- "autogpt_platform/docker-compose.yml"
|
||||||
|
- "autogpt_platform/docker-compose.platform.yml"
|
||||||
merge_group:
|
merge_group:
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
@@ -26,7 +32,6 @@ jobs:
|
|||||||
setup:
|
setup:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
outputs:
|
outputs:
|
||||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
|
||||||
components-changed: ${{ steps.filter.outputs.components }}
|
components-changed: ${{ steps.filter.outputs.components }}
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -41,28 +46,17 @@ jobs:
|
|||||||
components:
|
components:
|
||||||
- 'autogpt_platform/frontend/src/components/**'
|
- 'autogpt_platform/frontend/src/components/**'
|
||||||
|
|
||||||
- name: Set up Node.js
|
|
||||||
uses: actions/setup-node@v6
|
|
||||||
with:
|
|
||||||
node-version: "22.18.0"
|
|
||||||
|
|
||||||
- name: Enable corepack
|
- name: Enable corepack
|
||||||
run: corepack enable
|
run: corepack enable
|
||||||
|
|
||||||
- name: Generate cache key
|
- name: Set up Node
|
||||||
id: cache-key
|
uses: actions/setup-node@v6
|
||||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
|
||||||
|
|
||||||
- name: Cache dependencies
|
|
||||||
uses: actions/cache@v5
|
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
node-version: "22.18.0"
|
||||||
key: ${{ steps.cache-key.outputs.key }}
|
cache: "pnpm"
|
||||||
restore-keys: |
|
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
|
||||||
${{ runner.os }}-pnpm-
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies to populate cache
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
|
|
||||||
lint:
|
lint:
|
||||||
@@ -73,22 +67,15 @@ jobs:
|
|||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v6
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Set up Node.js
|
|
||||||
uses: actions/setup-node@v6
|
|
||||||
with:
|
|
||||||
node-version: "22.18.0"
|
|
||||||
|
|
||||||
- name: Enable corepack
|
- name: Enable corepack
|
||||||
run: corepack enable
|
run: corepack enable
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Set up Node
|
||||||
uses: actions/cache@v5
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
node-version: "22.18.0"
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
cache: "pnpm"
|
||||||
restore-keys: |
|
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
|
||||||
${{ runner.os }}-pnpm-
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
@@ -111,22 +98,15 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Set up Node.js
|
|
||||||
uses: actions/setup-node@v6
|
|
||||||
with:
|
|
||||||
node-version: "22.18.0"
|
|
||||||
|
|
||||||
- name: Enable corepack
|
- name: Enable corepack
|
||||||
run: corepack enable
|
run: corepack enable
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Set up Node
|
||||||
uses: actions/cache@v5
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
node-version: "22.18.0"
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
cache: "pnpm"
|
||||||
restore-keys: |
|
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
|
||||||
${{ runner.os }}-pnpm-
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
@@ -141,10 +121,8 @@ jobs:
|
|||||||
exitOnceUploaded: true
|
exitOnceUploaded: true
|
||||||
|
|
||||||
e2e_test:
|
e2e_test:
|
||||||
|
name: end-to-end tests
|
||||||
runs-on: big-boi
|
runs-on: big-boi
|
||||||
needs: setup
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
@@ -152,19 +130,11 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Platform - Copy default supabase .env
|
||||||
uses: actions/setup-node@v6
|
|
||||||
with:
|
|
||||||
node-version: "22.18.0"
|
|
||||||
|
|
||||||
- name: Enable corepack
|
|
||||||
run: corepack enable
|
|
||||||
|
|
||||||
- name: Copy default supabase .env
|
|
||||||
run: |
|
run: |
|
||||||
cp ../.env.default ../.env
|
cp ../.env.default ../.env
|
||||||
|
|
||||||
- name: Copy backend .env and set OpenAI API key
|
- name: Set up Platform - Copy backend .env and set OpenAI API key
|
||||||
run: |
|
run: |
|
||||||
cp ../backend/.env.default ../backend/.env
|
cp ../backend/.env.default ../backend/.env
|
||||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||||
@@ -172,77 +142,125 @@ jobs:
|
|||||||
# Used by E2E test data script to generate embeddings for approved store agents
|
# Used by E2E test data script to generate embeddings for approved store agents
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Platform - Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
|
with:
|
||||||
|
driver: docker-container
|
||||||
|
driver-opts: network=host
|
||||||
|
|
||||||
- name: Cache Docker layers
|
- name: Set up Platform - Expose GHA cache to docker buildx CLI
|
||||||
|
uses: crazy-max/ghaction-github-runtime@v4
|
||||||
|
|
||||||
|
- name: Set up Platform - Build Docker images (with cache)
|
||||||
|
working-directory: autogpt_platform
|
||||||
|
run: |
|
||||||
|
pip install pyyaml
|
||||||
|
|
||||||
|
# Resolve extends and generate a flat compose file that bake can understand
|
||||||
|
docker compose -f docker-compose.yml config > docker-compose.resolved.yml
|
||||||
|
|
||||||
|
# Add cache configuration to the resolved compose file
|
||||||
|
python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
|
||||||
|
--source docker-compose.resolved.yml \
|
||||||
|
--cache-from "type=gha" \
|
||||||
|
--cache-to "type=gha,mode=max" \
|
||||||
|
--backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \
|
||||||
|
--frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \
|
||||||
|
--git-ref "${{ github.ref }}"
|
||||||
|
|
||||||
|
# Build with bake using the resolved compose file (now includes cache config)
|
||||||
|
docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
|
||||||
|
env:
|
||||||
|
NEXT_PUBLIC_PW_TEST: true
|
||||||
|
|
||||||
|
- name: Set up tests - Cache E2E test data
|
||||||
|
id: e2e-data-cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: /tmp/.buildx-cache
|
path: /tmp/e2e_test_data.sql
|
||||||
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
|
key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }}
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-buildx-frontend-test-
|
|
||||||
|
|
||||||
- name: Run docker compose
|
- name: Set up Platform - Start Supabase DB + Auth
|
||||||
run: |
|
run: |
|
||||||
NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d
|
docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
|
||||||
|
echo "Waiting for database to be ready..."
|
||||||
|
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
|
||||||
|
echo "Waiting for auth service to be ready..."
|
||||||
|
timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
|
||||||
|
|
||||||
|
- name: Set up Platform - Run migrations
|
||||||
|
run: |
|
||||||
|
echo "Running migrations..."
|
||||||
|
docker compose -f ../docker-compose.resolved.yml run --rm migrate
|
||||||
|
echo "✅ Migrations completed"
|
||||||
env:
|
env:
|
||||||
DOCKER_BUILDKIT: 1
|
NEXT_PUBLIC_PW_TEST: true
|
||||||
BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache
|
|
||||||
BUILDX_CACHE_TO: type=local,dest=/tmp/.buildx-cache-new,mode=max
|
|
||||||
|
|
||||||
- name: Move cache
|
- name: Set up tests - Load cached E2E test data
|
||||||
|
if: steps.e2e-data-cache.outputs.cache-hit == 'true'
|
||||||
run: |
|
run: |
|
||||||
rm -rf /tmp/.buildx-cache
|
echo "✅ Found cached E2E test data, restoring..."
|
||||||
if [ -d "/tmp/.buildx-cache-new" ]; then
|
{
|
||||||
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
|
echo "SET session_replication_role = 'replica';"
|
||||||
fi
|
cat /tmp/e2e_test_data.sql
|
||||||
|
echo "SET session_replication_role = 'origin';"
|
||||||
|
} | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
|
||||||
|
# Refresh materialized views after restore
|
||||||
|
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||||
|
psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true
|
||||||
|
|
||||||
- name: Wait for services to be ready
|
echo "✅ E2E test data restored from cache"
|
||||||
|
|
||||||
|
- name: Set up Platform - Start (all other services)
|
||||||
run: |
|
run: |
|
||||||
|
docker compose -f ../docker-compose.resolved.yml up -d --no-build
|
||||||
echo "Waiting for rest_server to be ready..."
|
echo "Waiting for rest_server to be ready..."
|
||||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||||
echo "Waiting for database to be ready..."
|
env:
|
||||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
NEXT_PUBLIC_PW_TEST: true
|
||||||
|
|
||||||
- name: Create E2E test data
|
- name: Set up tests - Create E2E test data
|
||||||
|
if: steps.e2e-data-cache.outputs.cache-hit != 'true'
|
||||||
run: |
|
run: |
|
||||||
echo "Creating E2E test data..."
|
echo "Creating E2E test data..."
|
||||||
# First try to run the script from inside the container
|
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
|
||||||
if docker compose -f ../docker-compose.yml exec -T rest_server test -f /app/autogpt_platform/backend/test/e2e_test_data.py; then
|
docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||||
echo "✅ Found e2e_test_data.py in container, running it..."
|
echo "❌ E2E test data creation failed!"
|
||||||
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python backend/test/e2e_test_data.py" || {
|
docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
|
||||||
echo "❌ E2E test data creation failed!"
|
exit 1
|
||||||
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
|
}
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
else
|
|
||||||
echo "⚠️ e2e_test_data.py not found in container, copying and running..."
|
|
||||||
# Copy the script into the container and run it
|
|
||||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.yml ps -q rest_server):/tmp/e2e_test_data.py || {
|
|
||||||
echo "❌ Failed to copy script to container"
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
|
||||||
echo "❌ E2E test data creation failed!"
|
|
||||||
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
fi
|
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
# Dump auth.users + platform schema for cache (two separate dumps)
|
||||||
uses: actions/cache@v5
|
echo "Dumping database for cache..."
|
||||||
|
{
|
||||||
|
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||||
|
pg_dump -U postgres --data-only --column-inserts \
|
||||||
|
--table='auth.users' postgres
|
||||||
|
docker compose -f ../docker-compose.resolved.yml exec -T db \
|
||||||
|
pg_dump -U postgres --data-only --column-inserts \
|
||||||
|
--schema=platform \
|
||||||
|
--exclude-table='platform._prisma_migrations' \
|
||||||
|
--exclude-table='platform.apscheduler_jobs' \
|
||||||
|
--exclude-table='platform.apscheduler_jobs_batched_notifications' \
|
||||||
|
postgres
|
||||||
|
} > /tmp/e2e_test_data.sql
|
||||||
|
|
||||||
|
echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"
|
||||||
|
|
||||||
|
- name: Set up tests - Enable corepack
|
||||||
|
run: corepack enable
|
||||||
|
|
||||||
|
- name: Set up tests - Set up Node
|
||||||
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
node-version: "22.18.0"
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
cache: "pnpm"
|
||||||
restore-keys: |
|
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
|
||||||
${{ runner.os }}-pnpm-
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Set up tests - Install dependencies
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
|
|
||||||
- name: Install Browser 'chromium'
|
- name: Set up tests - Install browser 'chromium'
|
||||||
run: pnpm playwright install --with-deps chromium
|
run: pnpm playwright install --with-deps chromium
|
||||||
|
|
||||||
- name: Run Playwright tests
|
- name: Run Playwright tests
|
||||||
@@ -269,7 +287,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Print Final Docker Compose logs
|
- name: Print Final Docker Compose logs
|
||||||
if: always()
|
if: always()
|
||||||
run: docker compose -f ../docker-compose.yml logs
|
run: docker compose -f ../docker-compose.resolved.yml logs
|
||||||
|
|
||||||
integration_test:
|
integration_test:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
@@ -281,22 +299,15 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Node.js
|
|
||||||
uses: actions/setup-node@v6
|
|
||||||
with:
|
|
||||||
node-version: "22.18.0"
|
|
||||||
|
|
||||||
- name: Enable corepack
|
- name: Enable corepack
|
||||||
run: corepack enable
|
run: corepack enable
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Set up Node
|
||||||
uses: actions/cache@v5
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
node-version: "22.18.0"
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
cache: "pnpm"
|
||||||
restore-keys: |
|
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
|
||||||
${{ runner.os }}-pnpm-
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
|
|||||||
39
.github/workflows/pr-overlap-check.yml
vendored
Normal file
39
.github/workflows/pr-overlap-check.yml
vendored
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
name: PR Overlap Detection
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
types: [opened, synchronize, reopened]
|
||||||
|
branches:
|
||||||
|
- dev
|
||||||
|
- master
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
pull-requests: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
check-overlaps:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0 # Need full history for merge testing
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.11'
|
||||||
|
|
||||||
|
- name: Configure git
|
||||||
|
run: |
|
||||||
|
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||||
|
git config user.name "github-actions[bot]"
|
||||||
|
|
||||||
|
- name: Run overlap detection
|
||||||
|
env:
|
||||||
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
# Always succeed - this check informs contributors, it shouldn't block merging
|
||||||
|
continue-on-error: true
|
||||||
|
run: |
|
||||||
|
python .github/scripts/detect_overlaps.py ${{ github.event.pull_request.number }}
|
||||||
195
.github/workflows/scripts/docker-ci-fix-compose-build-cache.py
vendored
Normal file
195
.github/workflows/scripts/docker-ci-fix-compose-build-cache.py
vendored
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Add cache configuration to a resolved docker-compose file for all services
|
||||||
|
that have a build key, and ensure image names match what docker compose expects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_BRANCH = "dev"
|
||||||
|
CACHE_BUILDS_FOR_COMPONENTS = ["backend", "frontend"]
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Add cache config to a resolved compose file"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--source",
|
||||||
|
required=True,
|
||||||
|
help="Source compose file to read (should be output of `docker compose config`)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cache-from",
|
||||||
|
default="type=gha",
|
||||||
|
help="Cache source configuration",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cache-to",
|
||||||
|
default="type=gha,mode=max",
|
||||||
|
help="Cache destination configuration",
|
||||||
|
)
|
||||||
|
for component in CACHE_BUILDS_FOR_COMPONENTS:
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{component}-hash",
|
||||||
|
default="",
|
||||||
|
help=f"Hash for {component} cache scope (e.g., from hashFiles())",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--git-ref",
|
||||||
|
default="",
|
||||||
|
help="Git ref for branch-based cache scope (e.g., refs/heads/master)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Normalize git ref to a safe scope name (e.g., refs/heads/master -> master)
|
||||||
|
git_ref_scope = ""
|
||||||
|
if args.git_ref:
|
||||||
|
git_ref_scope = args.git_ref.replace("refs/heads/", "").replace("/", "-")
|
||||||
|
|
||||||
|
with open(args.source, "r") as f:
|
||||||
|
compose = yaml.safe_load(f)
|
||||||
|
|
||||||
|
# Get project name from compose file or default
|
||||||
|
project_name = compose.get("name", "autogpt_platform")
|
||||||
|
|
||||||
|
def get_image_name(dockerfile: str, target: str) -> str:
|
||||||
|
"""Generate image name based on Dockerfile folder and build target."""
|
||||||
|
dockerfile_parts = dockerfile.replace("\\", "/").split("/")
|
||||||
|
if len(dockerfile_parts) >= 2:
|
||||||
|
folder_name = dockerfile_parts[-2] # e.g., "backend" or "frontend"
|
||||||
|
else:
|
||||||
|
folder_name = "app"
|
||||||
|
return f"{project_name}-{folder_name}:{target}"
|
||||||
|
|
||||||
|
def get_build_key(dockerfile: str, target: str) -> str:
|
||||||
|
"""Generate a unique key for a Dockerfile+target combination."""
|
||||||
|
return f"{dockerfile}:{target}"
|
||||||
|
|
||||||
|
def get_component(dockerfile: str) -> str | None:
|
||||||
|
"""Get component name (frontend/backend) from dockerfile path."""
|
||||||
|
for component in CACHE_BUILDS_FOR_COMPONENTS:
|
||||||
|
if component in dockerfile:
|
||||||
|
return component
|
||||||
|
return None
|
||||||
|
|
||||||
|
# First pass: collect all services with build configs and identify duplicates
|
||||||
|
# Track which (dockerfile, target) combinations we've seen
|
||||||
|
build_key_to_first_service: dict[str, str] = {}
|
||||||
|
services_to_build: list[str] = []
|
||||||
|
services_to_dedupe: list[str] = []
|
||||||
|
|
||||||
|
for service_name, service_config in compose.get("services", {}).items():
|
||||||
|
if "build" not in service_config:
|
||||||
|
continue
|
||||||
|
|
||||||
|
build_config = service_config["build"]
|
||||||
|
dockerfile = build_config.get("dockerfile", "Dockerfile")
|
||||||
|
target = build_config.get("target", "default")
|
||||||
|
build_key = get_build_key(dockerfile, target)
|
||||||
|
|
||||||
|
if build_key not in build_key_to_first_service:
|
||||||
|
# First service with this build config - it will do the actual build
|
||||||
|
build_key_to_first_service[build_key] = service_name
|
||||||
|
services_to_build.append(service_name)
|
||||||
|
else:
|
||||||
|
# Duplicate - will just use the image from the first service
|
||||||
|
services_to_dedupe.append(service_name)
|
||||||
|
|
||||||
|
# Second pass: configure builds and deduplicate
|
||||||
|
modified_services = []
|
||||||
|
for service_name, service_config in compose.get("services", {}).items():
|
||||||
|
if "build" not in service_config:
|
||||||
|
continue
|
||||||
|
|
||||||
|
build_config = service_config["build"]
|
||||||
|
dockerfile = build_config.get("dockerfile", "Dockerfile")
|
||||||
|
target = build_config.get("target", "latest")
|
||||||
|
image_name = get_image_name(dockerfile, target)
|
||||||
|
|
||||||
|
# Set image name for all services (needed for both builders and deduped)
|
||||||
|
service_config["image"] = image_name
|
||||||
|
|
||||||
|
if service_name in services_to_dedupe:
|
||||||
|
# Remove build config - this service will use the pre-built image
|
||||||
|
del service_config["build"]
|
||||||
|
continue
|
||||||
|
|
||||||
|
# This service will do the actual build - add cache config
|
||||||
|
cache_from_list = []
|
||||||
|
cache_to_list = []
|
||||||
|
|
||||||
|
component = get_component(dockerfile)
|
||||||
|
if not component:
|
||||||
|
# Skip services that don't clearly match frontend/backend
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get the hash for this component
|
||||||
|
component_hash = getattr(args, f"{component}_hash")
|
||||||
|
|
||||||
|
# Scope format: platform-{component}-{target}-{hash|ref}
|
||||||
|
# Example: platform-backend-server-abc123
|
||||||
|
|
||||||
|
if "type=gha" in args.cache_from:
|
||||||
|
# 1. Primary: exact hash match (most specific)
|
||||||
|
if component_hash:
|
||||||
|
hash_scope = f"platform-{component}-{target}-{component_hash}"
|
||||||
|
cache_from_list.append(f"{args.cache_from},scope={hash_scope}")
|
||||||
|
|
||||||
|
# 2. Fallback: branch-based cache
|
||||||
|
if git_ref_scope:
|
||||||
|
ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
|
||||||
|
cache_from_list.append(f"{args.cache_from},scope={ref_scope}")
|
||||||
|
|
||||||
|
# 3. Fallback: dev branch cache (for PRs/feature branches)
|
||||||
|
if git_ref_scope and git_ref_scope != DEFAULT_BRANCH:
|
||||||
|
master_scope = f"platform-{component}-{target}-{DEFAULT_BRANCH}"
|
||||||
|
cache_from_list.append(f"{args.cache_from},scope={master_scope}")
|
||||||
|
|
||||||
|
if "type=gha" in args.cache_to:
|
||||||
|
# Write to both hash-based and branch-based scopes
|
||||||
|
if component_hash:
|
||||||
|
hash_scope = f"platform-{component}-{target}-{component_hash}"
|
||||||
|
cache_to_list.append(f"{args.cache_to},scope={hash_scope}")
|
||||||
|
|
||||||
|
if git_ref_scope:
|
||||||
|
ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
|
||||||
|
cache_to_list.append(f"{args.cache_to},scope={ref_scope}")
|
||||||
|
|
||||||
|
# Ensure we have at least one cache source/target
|
||||||
|
if not cache_from_list:
|
||||||
|
cache_from_list.append(args.cache_from)
|
||||||
|
if not cache_to_list:
|
||||||
|
cache_to_list.append(args.cache_to)
|
||||||
|
|
||||||
|
build_config["cache_from"] = cache_from_list
|
||||||
|
build_config["cache_to"] = cache_to_list
|
||||||
|
modified_services.append(service_name)
|
||||||
|
|
||||||
|
# Write back to the same file
|
||||||
|
with open(args.source, "w") as f:
|
||||||
|
yaml.dump(compose, f, default_flow_style=False, sort_keys=False)
|
||||||
|
|
||||||
|
print(f"Added cache config to {len(modified_services)} services in {args.source}:")
|
||||||
|
for svc in modified_services:
|
||||||
|
svc_config = compose["services"][svc]
|
||||||
|
build_cfg = svc_config.get("build", {})
|
||||||
|
cache_from_list = build_cfg.get("cache_from", ["none"])
|
||||||
|
cache_to_list = build_cfg.get("cache_to", ["none"])
|
||||||
|
print(f" - {svc}")
|
||||||
|
print(f" image: {svc_config.get('image', 'N/A')}")
|
||||||
|
print(f" cache_from: {cache_from_list}")
|
||||||
|
print(f" cache_to: {cache_to_list}")
|
||||||
|
if services_to_dedupe:
|
||||||
|
print(
|
||||||
|
f"Deduplicated {len(services_to_dedupe)} services (will use pre-built images):"
|
||||||
|
)
|
||||||
|
for svc in services_to_dedupe:
|
||||||
|
print(f" - {svc} -> {compose['services'][svc].get('image', 'N/A')}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -180,4 +180,6 @@ autogpt_platform/backend/settings.py
|
|||||||
.claude/settings.local.json
|
.claude/settings.local.json
|
||||||
CLAUDE.local.md
|
CLAUDE.local.md
|
||||||
/autogpt_platform/backend/logs
|
/autogpt_platform/backend/logs
|
||||||
.next
|
.next
|
||||||
|
# Implementation plans (generated by AI agents)
|
||||||
|
plans/
|
||||||
|
|||||||
@@ -1,3 +1,10 @@
|
|||||||
|
default_install_hook_types:
|
||||||
|
- pre-commit
|
||||||
|
- pre-push
|
||||||
|
- post-checkout
|
||||||
|
|
||||||
|
default_stages: [pre-commit]
|
||||||
|
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v4.4.0
|
rev: v4.4.0
|
||||||
@@ -17,6 +24,7 @@ repos:
|
|||||||
name: Detect secrets
|
name: Detect secrets
|
||||||
description: Detects high entropy strings that are likely to be passwords.
|
description: Detects high entropy strings that are likely to be passwords.
|
||||||
files: ^autogpt_platform/
|
files: ^autogpt_platform/
|
||||||
|
exclude: pnpm-lock\.yaml$
|
||||||
stages: [pre-push]
|
stages: [pre-push]
|
||||||
|
|
||||||
- repo: local
|
- repo: local
|
||||||
@@ -26,49 +34,106 @@ repos:
|
|||||||
- id: poetry-install
|
- id: poetry-install
|
||||||
name: Check & Install dependencies - AutoGPT Platform - Backend
|
name: Check & Install dependencies - AutoGPT Platform - Backend
|
||||||
alias: poetry-install-platform-backend
|
alias: poetry-install-platform-backend
|
||||||
entry: poetry -C autogpt_platform/backend install
|
|
||||||
# include autogpt_libs source (since it's a path dependency)
|
# include autogpt_libs source (since it's a path dependency)
|
||||||
files: ^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$
|
entry: >
|
||||||
types: [file]
|
bash -c '
|
||||||
|
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||||
|
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||||
|
else
|
||||||
|
git diff --cached --name-only
|
||||||
|
fi | grep -qE "^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$" || exit 0;
|
||||||
|
poetry -C autogpt_platform/backend install
|
||||||
|
'
|
||||||
|
always_run: true
|
||||||
language: system
|
language: system
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
|
stages: [pre-commit, post-checkout]
|
||||||
|
|
||||||
- id: poetry-install
|
- id: poetry-install
|
||||||
name: Check & Install dependencies - AutoGPT Platform - Libs
|
name: Check & Install dependencies - AutoGPT Platform - Libs
|
||||||
alias: poetry-install-platform-libs
|
alias: poetry-install-platform-libs
|
||||||
entry: poetry -C autogpt_platform/autogpt_libs install
|
entry: >
|
||||||
files: ^autogpt_platform/autogpt_libs/poetry\.lock$
|
bash -c '
|
||||||
types: [file]
|
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||||
|
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||||
|
else
|
||||||
|
git diff --cached --name-only
|
||||||
|
fi | grep -qE "^autogpt_platform/autogpt_libs/poetry\.lock$" || exit 0;
|
||||||
|
poetry -C autogpt_platform/autogpt_libs install
|
||||||
|
'
|
||||||
|
always_run: true
|
||||||
language: system
|
language: system
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
|
stages: [pre-commit, post-checkout]
|
||||||
|
|
||||||
|
- id: pnpm-install
|
||||||
|
name: Check & Install dependencies - AutoGPT Platform - Frontend
|
||||||
|
alias: pnpm-install-platform-frontend
|
||||||
|
entry: >
|
||||||
|
bash -c '
|
||||||
|
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||||
|
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||||
|
else
|
||||||
|
git diff --cached --name-only
|
||||||
|
fi | grep -qE "^autogpt_platform/frontend/pnpm-lock\.yaml$" || exit 0;
|
||||||
|
pnpm --prefix autogpt_platform/frontend install
|
||||||
|
'
|
||||||
|
always_run: true
|
||||||
|
language: system
|
||||||
|
pass_filenames: false
|
||||||
|
stages: [pre-commit, post-checkout]
|
||||||
|
|
||||||
- id: poetry-install
|
- id: poetry-install
|
||||||
name: Check & Install dependencies - Classic - AutoGPT
|
name: Check & Install dependencies - Classic - AutoGPT
|
||||||
alias: poetry-install-classic-autogpt
|
alias: poetry-install-classic-autogpt
|
||||||
entry: poetry -C classic/original_autogpt install
|
entry: >
|
||||||
|
bash -c '
|
||||||
|
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||||
|
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||||
|
else
|
||||||
|
git diff --cached --name-only
|
||||||
|
fi | grep -qE "^classic/(original_autogpt|forge)/poetry\.lock$" || exit 0;
|
||||||
|
poetry -C classic/original_autogpt install
|
||||||
|
'
|
||||||
# include forge source (since it's a path dependency)
|
# include forge source (since it's a path dependency)
|
||||||
files: ^classic/(original_autogpt|forge)/poetry\.lock$
|
always_run: true
|
||||||
types: [file]
|
|
||||||
language: system
|
language: system
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
|
stages: [pre-commit, post-checkout]
|
||||||
|
|
||||||
- id: poetry-install
|
- id: poetry-install
|
||||||
name: Check & Install dependencies - Classic - Forge
|
name: Check & Install dependencies - Classic - Forge
|
||||||
alias: poetry-install-classic-forge
|
alias: poetry-install-classic-forge
|
||||||
entry: poetry -C classic/forge install
|
entry: >
|
||||||
files: ^classic/forge/poetry\.lock$
|
bash -c '
|
||||||
types: [file]
|
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||||
|
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||||
|
else
|
||||||
|
git diff --cached --name-only
|
||||||
|
fi | grep -qE "^classic/forge/poetry\.lock$" || exit 0;
|
||||||
|
poetry -C classic/forge install
|
||||||
|
'
|
||||||
|
always_run: true
|
||||||
language: system
|
language: system
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
|
stages: [pre-commit, post-checkout]
|
||||||
|
|
||||||
- id: poetry-install
|
- id: poetry-install
|
||||||
name: Check & Install dependencies - Classic - Benchmark
|
name: Check & Install dependencies - Classic - Benchmark
|
||||||
alias: poetry-install-classic-benchmark
|
alias: poetry-install-classic-benchmark
|
||||||
entry: poetry -C classic/benchmark install
|
entry: >
|
||||||
files: ^classic/benchmark/poetry\.lock$
|
bash -c '
|
||||||
types: [file]
|
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||||
|
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||||
|
else
|
||||||
|
git diff --cached --name-only
|
||||||
|
fi | grep -qE "^classic/benchmark/poetry\.lock$" || exit 0;
|
||||||
|
poetry -C classic/benchmark install
|
||||||
|
'
|
||||||
|
always_run: true
|
||||||
language: system
|
language: system
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
|
stages: [pre-commit, post-checkout]
|
||||||
|
|
||||||
- repo: local
|
- repo: local
|
||||||
# For proper type checking, Prisma client must be up-to-date.
|
# For proper type checking, Prisma client must be up-to-date.
|
||||||
@@ -76,12 +141,54 @@ repos:
|
|||||||
- id: prisma-generate
|
- id: prisma-generate
|
||||||
name: Prisma Generate - AutoGPT Platform - Backend
|
name: Prisma Generate - AutoGPT Platform - Backend
|
||||||
alias: prisma-generate-platform-backend
|
alias: prisma-generate-platform-backend
|
||||||
entry: bash -c 'cd autogpt_platform/backend && poetry run prisma generate'
|
entry: >
|
||||||
|
bash -c '
|
||||||
|
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||||
|
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
|
||||||
|
else
|
||||||
|
git diff --cached --name-only
|
||||||
|
fi | grep -qE "^autogpt_platform/((backend|autogpt_libs)/poetry\.lock|backend/schema\.prisma)$" || exit 0;
|
||||||
|
cd autogpt_platform/backend
|
||||||
|
&& poetry run prisma generate
|
||||||
|
&& poetry run gen-prisma-stub
|
||||||
|
'
|
||||||
# include everything that triggers poetry install + the prisma schema
|
# include everything that triggers poetry install + the prisma schema
|
||||||
files: ^autogpt_platform/((backend|autogpt_libs)/poetry\.lock|backend/schema.prisma)$
|
always_run: true
|
||||||
types: [file]
|
|
||||||
language: system
|
language: system
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
|
stages: [pre-commit, post-checkout]
|
||||||
|
|
||||||
|
- id: export-api-schema
|
||||||
|
name: Export API schema - AutoGPT Platform - Backend -> Frontend
|
||||||
|
alias: export-api-schema-platform
|
||||||
|
entry: >
|
||||||
|
bash -c '
|
||||||
|
cd autogpt_platform/backend
|
||||||
|
&& poetry run export-api-schema --output ../frontend/src/app/api/openapi.json
|
||||||
|
&& cd ../frontend
|
||||||
|
&& pnpm prettier --write ./src/app/api/openapi.json
|
||||||
|
'
|
||||||
|
files: ^autogpt_platform/backend/
|
||||||
|
language: system
|
||||||
|
pass_filenames: false
|
||||||
|
|
||||||
|
- id: generate-api-client
|
||||||
|
name: Generate API client - AutoGPT Platform - Frontend
|
||||||
|
alias: generate-api-client-platform-frontend
|
||||||
|
entry: >
|
||||||
|
bash -c '
|
||||||
|
SCHEMA=autogpt_platform/frontend/src/app/api/openapi.json;
|
||||||
|
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
|
||||||
|
git diff --quiet "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF" -- "$SCHEMA" && exit 0
|
||||||
|
else
|
||||||
|
git diff --quiet HEAD -- "$SCHEMA" && exit 0
|
||||||
|
fi;
|
||||||
|
cd autogpt_platform/frontend && pnpm generate:api
|
||||||
|
'
|
||||||
|
always_run: true
|
||||||
|
language: system
|
||||||
|
pass_filenames: false
|
||||||
|
stages: [pre-commit, post-checkout]
|
||||||
|
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.7.2
|
rev: v0.7.2
|
||||||
|
|||||||
3
autogpt_platform/.gitignore
vendored
3
autogpt_platform/.gitignore
vendored
@@ -1,2 +1,3 @@
|
|||||||
*.ignore.*
|
*.ignore.*
|
||||||
*.ign.*
|
*.ign.*
|
||||||
|
.application.logs
|
||||||
|
|||||||
@@ -45,6 +45,11 @@ AutoGPT Platform is a monorepo containing:
|
|||||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||||
|
|
||||||
|
### Branching Strategy
|
||||||
|
|
||||||
|
- **`dev`** is the main development branch. All PRs should target `dev`.
|
||||||
|
- **`master`** is the production branch. Only used for production releases.
|
||||||
|
|
||||||
### Creating Pull Requests
|
### Creating Pull Requests
|
||||||
|
|
||||||
- Create the PR against the `dev` branch of the repository.
|
- Create the PR against the `dev` branch of the repository.
|
||||||
@@ -55,9 +60,12 @@ AutoGPT Platform is a monorepo containing:
|
|||||||
|
|
||||||
### Reviewing/Revising Pull Requests
|
### Reviewing/Revising Pull Requests
|
||||||
|
|
||||||
- When the user runs /pr-comments or tries to fetch them, also run gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews to get the reviews
|
Use `/pr-review` to review a PR or `/pr-address` to address comments.
|
||||||
- Use gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews/[review_id]/comments to get the review contents
|
|
||||||
- Use gh api /repos/Significant-Gravitas/AutoGPT/issues/9924/comments to get the pr specific comments
|
When fetching comments manually:
|
||||||
|
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` — top-level reviews
|
||||||
|
- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` — inline review comments
|
||||||
|
- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
|
||||||
|
|
||||||
### Conventional Commits
|
### Conventional Commits
|
||||||
|
|
||||||
|
|||||||
40
autogpt_platform/analytics/queries/auth_activities.sql
Normal file
40
autogpt_platform/analytics/queries/auth_activities.sql
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.auth_activities
|
||||||
|
-- Looker source alias: ds49 | Charts: 1
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- Tracks authentication events (login, logout, SSO, password
|
||||||
|
-- reset, etc.) from Supabase's internal audit log.
|
||||||
|
-- Useful for monitoring sign-in patterns and detecting anomalies.
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- auth.audit_log_entries — Supabase internal auth event log
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- created_at TIMESTAMPTZ When the auth event occurred
|
||||||
|
-- actor_id TEXT User ID who triggered the event
|
||||||
|
-- actor_via_sso TEXT Whether the action was via SSO ('true'/'false')
|
||||||
|
-- action TEXT Event type (e.g. 'login', 'logout', 'token_refreshed')
|
||||||
|
--
|
||||||
|
-- WINDOW
|
||||||
|
-- Rolling 90 days from current date
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Daily login counts
|
||||||
|
-- SELECT DATE_TRUNC('day', created_at) AS day, COUNT(*) AS logins
|
||||||
|
-- FROM analytics.auth_activities
|
||||||
|
-- WHERE action = 'login'
|
||||||
|
-- GROUP BY 1 ORDER BY 1;
|
||||||
|
--
|
||||||
|
-- -- SSO vs password login breakdown
|
||||||
|
-- SELECT actor_via_sso, COUNT(*) FROM analytics.auth_activities
|
||||||
|
-- WHERE action = 'login' GROUP BY 1;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
SELECT
|
||||||
|
created_at,
|
||||||
|
payload->>'actor_id' AS actor_id,
|
||||||
|
payload->>'actor_via_sso' AS actor_via_sso,
|
||||||
|
payload->>'action' AS action
|
||||||
|
FROM auth.audit_log_entries
|
||||||
|
WHERE created_at >= NOW() - INTERVAL '90 days'
|
||||||
105
autogpt_platform/analytics/queries/graph_execution.sql
Normal file
105
autogpt_platform/analytics/queries/graph_execution.sql
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.graph_execution
|
||||||
|
-- Looker source alias: ds16 | Charts: 21
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- One row per agent graph execution (last 90 days).
|
||||||
|
-- Unpacks the JSONB stats column into individual numeric columns
|
||||||
|
-- and normalises the executionStatus — runs that failed due to
|
||||||
|
-- insufficient credits are reclassified as 'NO_CREDITS' for
|
||||||
|
-- easier filtering. Error messages are scrubbed of IDs and URLs
|
||||||
|
-- to allow safe grouping.
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- platform.AgentGraphExecution — Execution records
|
||||||
|
-- platform.AgentGraph — Agent graph metadata (for name)
|
||||||
|
-- platform.LibraryAgent — To flag possibly-AI (safe-mode) agents
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- id TEXT Execution UUID
|
||||||
|
-- agentGraphId TEXT Agent graph UUID
|
||||||
|
-- agentGraphVersion INT Graph version number
|
||||||
|
-- executionStatus TEXT COMPLETED | FAILED | NO_CREDITS | RUNNING | QUEUED | TERMINATED
|
||||||
|
-- createdAt TIMESTAMPTZ When the execution was queued
|
||||||
|
-- updatedAt TIMESTAMPTZ Last status update time
|
||||||
|
-- userId TEXT Owner user UUID
|
||||||
|
-- agentGraphName TEXT Human-readable agent name
|
||||||
|
-- cputime DECIMAL Total CPU seconds consumed
|
||||||
|
-- walltime DECIMAL Total wall-clock seconds
|
||||||
|
-- node_count DECIMAL Number of nodes in the graph
|
||||||
|
-- nodes_cputime DECIMAL CPU time across all nodes
|
||||||
|
-- nodes_walltime DECIMAL Wall time across all nodes
|
||||||
|
-- execution_cost DECIMAL Credit cost of this execution
|
||||||
|
-- correctness_score FLOAT AI correctness score (if available)
|
||||||
|
-- possibly_ai BOOLEAN True if agent has sensitive_action_safe_mode enabled
|
||||||
|
-- groupedErrorMessage TEXT Scrubbed error string (IDs/URLs replaced with wildcards)
|
||||||
|
--
|
||||||
|
-- WINDOW
|
||||||
|
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Daily execution counts by status
|
||||||
|
-- SELECT DATE_TRUNC('day', "createdAt") AS day, "executionStatus", COUNT(*)
|
||||||
|
-- FROM analytics.graph_execution
|
||||||
|
-- GROUP BY 1, 2 ORDER BY 1;
|
||||||
|
--
|
||||||
|
-- -- Average cost per execution by agent
|
||||||
|
-- SELECT "agentGraphName", AVG("execution_cost") AS avg_cost, COUNT(*) AS runs
|
||||||
|
-- FROM analytics.graph_execution
|
||||||
|
-- WHERE "executionStatus" = 'COMPLETED'
|
||||||
|
-- GROUP BY 1 ORDER BY avg_cost DESC;
|
||||||
|
--
|
||||||
|
-- -- Top error messages
|
||||||
|
-- SELECT "groupedErrorMessage", COUNT(*) AS occurrences
|
||||||
|
-- FROM analytics.graph_execution
|
||||||
|
-- WHERE "executionStatus" = 'FAILED'
|
||||||
|
-- GROUP BY 1 ORDER BY 2 DESC LIMIT 20;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
SELECT
|
||||||
|
ge."id" AS id,
|
||||||
|
ge."agentGraphId" AS agentGraphId,
|
||||||
|
ge."agentGraphVersion" AS agentGraphVersion,
|
||||||
|
CASE
|
||||||
|
WHEN jsonb_exists(ge."stats"::jsonb, 'error')
|
||||||
|
AND (
|
||||||
|
(ge."stats"::jsonb->>'error') ILIKE '%insufficient balance%'
|
||||||
|
OR (ge."stats"::jsonb->>'error') ILIKE '%you have no credits left%'
|
||||||
|
)
|
||||||
|
THEN 'NO_CREDITS'
|
||||||
|
ELSE CAST(ge."executionStatus" AS TEXT)
|
||||||
|
END AS executionStatus,
|
||||||
|
ge."createdAt" AS createdAt,
|
||||||
|
ge."updatedAt" AS updatedAt,
|
||||||
|
ge."userId" AS userId,
|
||||||
|
g."name" AS agentGraphName,
|
||||||
|
(ge."stats"::jsonb->>'cputime')::decimal AS cputime,
|
||||||
|
(ge."stats"::jsonb->>'walltime')::decimal AS walltime,
|
||||||
|
(ge."stats"::jsonb->>'node_count')::decimal AS node_count,
|
||||||
|
(ge."stats"::jsonb->>'nodes_cputime')::decimal AS nodes_cputime,
|
||||||
|
(ge."stats"::jsonb->>'nodes_walltime')::decimal AS nodes_walltime,
|
||||||
|
(ge."stats"::jsonb->>'cost')::decimal AS execution_cost,
|
||||||
|
(ge."stats"::jsonb->>'correctness_score')::float AS correctness_score,
|
||||||
|
COALESCE(la.possibly_ai, FALSE) AS possibly_ai,
|
||||||
|
REGEXP_REPLACE(
|
||||||
|
REGEXP_REPLACE(
|
||||||
|
TRIM(BOTH '"' FROM ge."stats"::jsonb->>'error'),
|
||||||
|
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
|
||||||
|
'\1\2/...', 'gi'
|
||||||
|
),
|
||||||
|
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
|
||||||
|
) AS groupedErrorMessage
|
||||||
|
FROM platform."AgentGraphExecution" ge
|
||||||
|
LEFT JOIN platform."AgentGraph" g
|
||||||
|
ON ge."agentGraphId" = g."id"
|
||||||
|
AND ge."agentGraphVersion" = g."version"
|
||||||
|
LEFT JOIN (
|
||||||
|
SELECT DISTINCT ON ("userId", "agentGraphId")
|
||||||
|
"userId", "agentGraphId",
|
||||||
|
("settings"::jsonb->>'sensitive_action_safe_mode')::boolean AS possibly_ai
|
||||||
|
FROM platform."LibraryAgent"
|
||||||
|
WHERE "isDeleted" = FALSE
|
||||||
|
AND "isArchived" = FALSE
|
||||||
|
ORDER BY "userId", "agentGraphId", "agentGraphVersion" DESC
|
||||||
|
) la ON la."userId" = ge."userId" AND la."agentGraphId" = ge."agentGraphId"
|
||||||
|
WHERE ge."createdAt" > CURRENT_DATE - INTERVAL '90 days'
|
||||||
101
autogpt_platform/analytics/queries/node_block_execution.sql
Normal file
101
autogpt_platform/analytics/queries/node_block_execution.sql
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.node_block_execution
|
||||||
|
-- Looker source alias: ds14 | Charts: 11
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- One row per node (block) execution (last 90 days).
|
||||||
|
-- Unpacks stats JSONB and joins to identify which block type
|
||||||
|
-- was run. For failed nodes, joins the error output and
|
||||||
|
-- scrubs it for safe grouping.
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- platform.AgentNodeExecution — Node execution records
|
||||||
|
-- platform.AgentNode — Node → block mapping
|
||||||
|
-- platform.AgentBlock — Block name/ID
|
||||||
|
-- platform.AgentNodeExecutionInputOutput — Error output values
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- id TEXT Node execution UUID
|
||||||
|
-- agentGraphExecutionId TEXT Parent graph execution UUID
|
||||||
|
-- agentNodeId TEXT Node UUID within the graph
|
||||||
|
-- executionStatus TEXT COMPLETED | FAILED | QUEUED | RUNNING | TERMINATED
|
||||||
|
-- addedTime TIMESTAMPTZ When the node was queued
|
||||||
|
-- queuedTime TIMESTAMPTZ When it entered the queue
|
||||||
|
-- startedTime TIMESTAMPTZ When execution started
|
||||||
|
-- endedTime TIMESTAMPTZ When execution finished
|
||||||
|
-- inputSize BIGINT Input payload size in bytes
|
||||||
|
-- outputSize BIGINT Output payload size in bytes
|
||||||
|
-- walltime NUMERIC Wall-clock seconds for this node
|
||||||
|
-- cputime NUMERIC CPU seconds for this node
|
||||||
|
-- llmRetryCount INT Number of LLM retries
|
||||||
|
-- llmCallCount INT Number of LLM API calls made
|
||||||
|
-- inputTokenCount BIGINT LLM input tokens consumed
|
||||||
|
-- outputTokenCount BIGINT LLM output tokens produced
|
||||||
|
-- blockName TEXT Human-readable block name (e.g. 'OpenAIBlock')
|
||||||
|
-- blockId TEXT Block UUID
|
||||||
|
-- groupedErrorMessage TEXT Scrubbed error (IDs/URLs wildcarded)
|
||||||
|
-- errorMessage TEXT Raw error output (only set when FAILED)
|
||||||
|
--
|
||||||
|
-- WINDOW
|
||||||
|
-- Rolling 90 days (addedTime > CURRENT_DATE - 90 days)
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Most-used blocks by execution count
|
||||||
|
-- SELECT "blockName", COUNT(*) AS executions,
|
||||||
|
-- COUNT(*) FILTER (WHERE "executionStatus"='FAILED') AS failures
|
||||||
|
-- FROM analytics.node_block_execution
|
||||||
|
-- GROUP BY 1 ORDER BY executions DESC LIMIT 20;
|
||||||
|
--
|
||||||
|
-- -- Average LLM token usage per block
|
||||||
|
-- SELECT "blockName",
|
||||||
|
-- AVG("inputTokenCount") AS avg_input_tokens,
|
||||||
|
-- AVG("outputTokenCount") AS avg_output_tokens
|
||||||
|
-- FROM analytics.node_block_execution
|
||||||
|
-- WHERE "llmCallCount" > 0
|
||||||
|
-- GROUP BY 1 ORDER BY avg_input_tokens DESC;
|
||||||
|
--
|
||||||
|
-- -- Top failure reasons
|
||||||
|
-- SELECT "blockName", "groupedErrorMessage", COUNT(*) AS count
|
||||||
|
-- FROM analytics.node_block_execution
|
||||||
|
-- WHERE "executionStatus" = 'FAILED'
|
||||||
|
-- GROUP BY 1, 2 ORDER BY count DESC LIMIT 20;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
SELECT
|
||||||
|
ne."id" AS id,
|
||||||
|
ne."agentGraphExecutionId" AS agentGraphExecutionId,
|
||||||
|
ne."agentNodeId" AS agentNodeId,
|
||||||
|
CAST(ne."executionStatus" AS TEXT) AS executionStatus,
|
||||||
|
ne."addedTime" AS addedTime,
|
||||||
|
ne."queuedTime" AS queuedTime,
|
||||||
|
ne."startedTime" AS startedTime,
|
||||||
|
ne."endedTime" AS endedTime,
|
||||||
|
(ne."stats"::jsonb->>'input_size')::bigint AS inputSize,
|
||||||
|
(ne."stats"::jsonb->>'output_size')::bigint AS outputSize,
|
||||||
|
(ne."stats"::jsonb->>'walltime')::numeric AS walltime,
|
||||||
|
(ne."stats"::jsonb->>'cputime')::numeric AS cputime,
|
||||||
|
(ne."stats"::jsonb->>'llm_retry_count')::int AS llmRetryCount,
|
||||||
|
(ne."stats"::jsonb->>'llm_call_count')::int AS llmCallCount,
|
||||||
|
(ne."stats"::jsonb->>'input_token_count')::bigint AS inputTokenCount,
|
||||||
|
(ne."stats"::jsonb->>'output_token_count')::bigint AS outputTokenCount,
|
||||||
|
b."name" AS blockName,
|
||||||
|
b."id" AS blockId,
|
||||||
|
REGEXP_REPLACE(
|
||||||
|
REGEXP_REPLACE(
|
||||||
|
TRIM(BOTH '"' FROM eio."data"::text),
|
||||||
|
'(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
|
||||||
|
'\1\2/...', 'gi'
|
||||||
|
),
|
||||||
|
'[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
|
||||||
|
) AS groupedErrorMessage,
|
||||||
|
eio."data" AS errorMessage
|
||||||
|
FROM platform."AgentNodeExecution" ne
|
||||||
|
LEFT JOIN platform."AgentNode" nd
|
||||||
|
ON ne."agentNodeId" = nd."id"
|
||||||
|
LEFT JOIN platform."AgentBlock" b
|
||||||
|
ON nd."agentBlockId" = b."id"
|
||||||
|
LEFT JOIN platform."AgentNodeExecutionInputOutput" eio
|
||||||
|
ON eio."referencedByOutputExecId" = ne."id"
|
||||||
|
AND eio."name" = 'error'
|
||||||
|
AND ne."executionStatus" = 'FAILED'
|
||||||
|
WHERE ne."addedTime" > CURRENT_DATE - INTERVAL '90 days'
|
||||||
97
autogpt_platform/analytics/queries/retention_agent.sql
Normal file
97
autogpt_platform/analytics/queries/retention_agent.sql
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.retention_agent
|
||||||
|
-- Looker source alias: ds35 | Charts: 2
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- Weekly cohort retention broken down per individual agent.
|
||||||
|
-- Cohort = week of a user's first use of THAT specific agent.
|
||||||
|
-- Tells you which agents keep users coming back vs. one-shot
|
||||||
|
-- use. Only includes cohorts from the last 180 days.
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- platform.AgentGraphExecution — Execution records (user × agent × time)
|
||||||
|
-- platform.AgentGraph — Agent names
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- agent_id TEXT Agent graph UUID
|
||||||
|
-- agent_label TEXT 'AgentName [first8chars]'
|
||||||
|
-- agent_label_n TEXT 'AgentName [first8chars] (n=total_users)'
|
||||||
|
-- cohort_week_start DATE Week users first ran this agent
|
||||||
|
-- cohort_label TEXT ISO week label
|
||||||
|
-- cohort_label_n TEXT ISO week label with cohort size
|
||||||
|
-- user_lifetime_week INT Weeks since first use of this agent
|
||||||
|
-- cohort_users BIGINT Users in this cohort for this agent
|
||||||
|
-- active_users BIGINT Users who ran the agent again in week k
|
||||||
|
-- retention_rate FLOAT active_users / cohort_users
|
||||||
|
-- cohort_users_w0 BIGINT cohort_users only at week 0 (safe to SUM)
|
||||||
|
-- agent_total_users BIGINT Total users across all cohorts for this agent
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Best-retained agents at week 2
|
||||||
|
-- SELECT agent_label, AVG(retention_rate) AS w2_retention
|
||||||
|
-- FROM analytics.retention_agent
|
||||||
|
-- WHERE user_lifetime_week = 2 AND cohort_users >= 10
|
||||||
|
-- GROUP BY 1 ORDER BY w2_retention DESC LIMIT 10;
|
||||||
|
--
|
||||||
|
-- -- Agents with most unique users
|
||||||
|
-- SELECT DISTINCT agent_label, agent_total_users
|
||||||
|
-- FROM analytics.retention_agent
|
||||||
|
-- ORDER BY agent_total_users DESC LIMIT 20;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
|
||||||
|
events AS (
|
||||||
|
SELECT e."userId"::text AS user_id, e."agentGraphId" AS agent_id,
|
||||||
|
e."createdAt"::timestamptz AS created_at,
|
||||||
|
DATE_TRUNC('week', e."createdAt")::date AS week_start
|
||||||
|
FROM platform."AgentGraphExecution" e
|
||||||
|
),
|
||||||
|
first_use AS (
|
||||||
|
SELECT user_id, agent_id, MIN(created_at) AS first_use_at,
|
||||||
|
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||||
|
FROM events GROUP BY 1,2
|
||||||
|
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||||
|
),
|
||||||
|
activity_weeks AS (SELECT DISTINCT user_id, agent_id, week_start FROM events),
|
||||||
|
user_week_age AS (
|
||||||
|
SELECT aw.user_id, aw.agent_id, fu.cohort_week_start,
|
||||||
|
((aw.week_start - DATE_TRUNC('week',fu.first_use_at)::date)/7)::int AS user_lifetime_week
|
||||||
|
FROM activity_weeks aw JOIN first_use fu USING (user_id, agent_id)
|
||||||
|
WHERE aw.week_start >= DATE_TRUNC('week',fu.first_use_at)::date
|
||||||
|
),
|
||||||
|
active_counts AS (
|
||||||
|
SELECT agent_id, cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users
|
||||||
|
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2,3
|
||||||
|
),
|
||||||
|
cohort_sizes AS (
|
||||||
|
SELECT agent_id, cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_use GROUP BY 1,2
|
||||||
|
),
|
||||||
|
cohort_caps AS (
|
||||||
|
SELECT cs.agent_id, cs.cohort_week_start, cs.cohort_users,
|
||||||
|
LEAST((SELECT max_weeks FROM params),
|
||||||
|
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||||
|
FROM cohort_sizes cs
|
||||||
|
),
|
||||||
|
grid AS (
|
||||||
|
SELECT cc.agent_id, cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||||
|
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||||
|
),
|
||||||
|
agent_names AS (SELECT DISTINCT ON (g."id") g."id" AS agent_id, g."name" AS agent_name FROM platform."AgentGraph" g ORDER BY g."id", g."version" DESC),
|
||||||
|
agent_total_users AS (SELECT agent_id, SUM(cohort_users) AS agent_total_users FROM cohort_sizes GROUP BY 1)
|
||||||
|
SELECT
|
||||||
|
g.agent_id,
|
||||||
|
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||']' AS agent_label,
|
||||||
|
COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||'] (n='||COALESCE(atu.agent_total_users,0)||')' AS agent_label_n,
|
||||||
|
g.cohort_week_start,
|
||||||
|
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||||
|
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||||
|
g.user_lifetime_week, g.cohort_users,
|
||||||
|
COALESCE(ac.active_users,0) AS active_users,
|
||||||
|
COALESCE(ac.active_users,0)::float / NULLIF(g.cohort_users,0) AS retention_rate,
|
||||||
|
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0,
|
||||||
|
COALESCE(atu.agent_total_users,0) AS agent_total_users
|
||||||
|
FROM grid g
|
||||||
|
LEFT JOIN active_counts ac ON ac.agent_id=g.agent_id AND ac.cohort_week_start=g.cohort_week_start AND ac.user_lifetime_week=g.user_lifetime_week
|
||||||
|
LEFT JOIN agent_names an ON an.agent_id=g.agent_id
|
||||||
|
LEFT JOIN agent_total_users atu ON atu.agent_id=g.agent_id
|
||||||
|
ORDER BY agent_label, g.cohort_week_start, g.user_lifetime_week;
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.retention_execution_daily
|
||||||
|
-- Looker source alias: ds111 | Charts: 1
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- Daily cohort retention based on agent executions.
|
||||||
|
-- Cohort anchor = day of user's FIRST ever execution.
|
||||||
|
-- Only includes cohorts from the last 90 days, up to day 30.
|
||||||
|
-- Great for early engagement analysis (did users run another
|
||||||
|
-- agent the next day?).
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- platform.AgentGraphExecution — Execution records
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- Same pattern as retention_login_daily.
|
||||||
|
-- cohort_day_start = day of first execution (not first login)
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Day-3 execution retention
|
||||||
|
-- SELECT cohort_label, retention_rate_bounded AS d3_retention
|
||||||
|
-- FROM analytics.retention_execution_daily
|
||||||
|
-- WHERE user_lifetime_day = 3 ORDER BY cohort_day_start;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days') AS cohort_start),
|
||||||
|
events AS (
|
||||||
|
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
|
||||||
|
DATE_TRUNC('day', e."createdAt")::date AS day_start
|
||||||
|
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
|
||||||
|
),
|
||||||
|
first_exec AS (
|
||||||
|
SELECT user_id, MIN(created_at) AS first_exec_at,
|
||||||
|
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
|
||||||
|
FROM events GROUP BY 1
|
||||||
|
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||||
|
),
|
||||||
|
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
|
||||||
|
user_day_age AS (
|
||||||
|
SELECT ad.user_id, fe.cohort_day_start,
|
||||||
|
(ad.day_start - DATE_TRUNC('day',fe.first_exec_at)::date)::int AS user_lifetime_day
|
||||||
|
FROM activity_days ad JOIN first_exec fe USING (user_id)
|
||||||
|
WHERE ad.day_start >= DATE_TRUNC('day',fe.first_exec_at)::date
|
||||||
|
),
|
||||||
|
bounded_counts AS (
|
||||||
|
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||||
|
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
|
||||||
|
),
|
||||||
|
last_active AS (
|
||||||
|
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
|
||||||
|
),
|
||||||
|
unbounded_counts AS (
|
||||||
|
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
|
||||||
|
FROM last_active la
|
||||||
|
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
|
||||||
|
GROUP BY 1,2
|
||||||
|
),
|
||||||
|
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
|
||||||
|
cohort_caps AS (
|
||||||
|
SELECT cs.cohort_day_start, cs.cohort_users,
|
||||||
|
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
|
||||||
|
FROM cohort_sizes cs
|
||||||
|
),
|
||||||
|
grid AS (
|
||||||
|
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
|
||||||
|
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
g.cohort_day_start,
|
||||||
|
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
|
||||||
|
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||||
|
g.user_lifetime_day, g.cohort_users,
|
||||||
|
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||||
|
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||||
|
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||||
|
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||||
|
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
|
||||||
|
FROM grid g
|
||||||
|
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
|
||||||
|
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
|
||||||
|
ORDER BY g.cohort_day_start, g.user_lifetime_day;
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.retention_execution_weekly
|
||||||
|
-- Looker source alias: ds92 | Charts: 2
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- Weekly cohort retention based on agent executions.
|
||||||
|
-- Cohort anchor = week of user's FIRST ever agent execution
|
||||||
|
-- (not first login). Only includes cohorts from the last 180 days.
|
||||||
|
-- Useful when you care about product engagement, not just visits.
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- platform.AgentGraphExecution — Execution records
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- Same pattern as retention_login_weekly.
|
||||||
|
-- cohort_week_start = week of first execution (not first login)
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Week-2 execution retention
|
||||||
|
-- SELECT cohort_label, retention_rate_bounded
|
||||||
|
-- FROM analytics.retention_execution_weekly
|
||||||
|
-- WHERE user_lifetime_week = 2 ORDER BY cohort_week_start;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
|
||||||
|
events AS (
|
||||||
|
SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
|
||||||
|
DATE_TRUNC('week', e."createdAt")::date AS week_start
|
||||||
|
FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
|
||||||
|
),
|
||||||
|
first_exec AS (
|
||||||
|
SELECT user_id, MIN(created_at) AS first_exec_at,
|
||||||
|
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||||
|
FROM events GROUP BY 1
|
||||||
|
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||||
|
),
|
||||||
|
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||||
|
user_week_age AS (
|
||||||
|
SELECT aw.user_id, fe.cohort_week_start,
|
||||||
|
((aw.week_start - DATE_TRUNC('week',fe.first_exec_at)::date)/7)::int AS user_lifetime_week
|
||||||
|
FROM activity_weeks aw JOIN first_exec fe USING (user_id)
|
||||||
|
WHERE aw.week_start >= DATE_TRUNC('week',fe.first_exec_at)::date
|
||||||
|
),
|
||||||
|
bounded_counts AS (
|
||||||
|
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||||
|
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||||
|
),
|
||||||
|
last_active AS (
|
||||||
|
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||||
|
),
|
||||||
|
unbounded_counts AS (
|
||||||
|
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||||
|
FROM last_active la
|
||||||
|
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||||
|
GROUP BY 1,2
|
||||||
|
),
|
||||||
|
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
|
||||||
|
cohort_caps AS (
|
||||||
|
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||||
|
LEAST((SELECT max_weeks FROM params),
|
||||||
|
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||||
|
FROM cohort_sizes cs
|
||||||
|
),
|
||||||
|
grid AS (
|
||||||
|
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||||
|
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
g.cohort_week_start,
|
||||||
|
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||||
|
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||||
|
g.user_lifetime_week, g.cohort_users,
|
||||||
|
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||||
|
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||||
|
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||||
|
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||||
|
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||||
|
FROM grid g
|
||||||
|
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||||
|
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||||
|
ORDER BY g.cohort_week_start, g.user_lifetime_week;
|
||||||
94
autogpt_platform/analytics/queries/retention_login_daily.sql
Normal file
94
autogpt_platform/analytics/queries/retention_login_daily.sql
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.retention_login_daily
|
||||||
|
-- Looker source alias: ds112 | Charts: 1
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- Daily cohort retention based on login sessions.
|
||||||
|
-- Same logic as retention_login_weekly but at day granularity,
|
||||||
|
-- showing up to day 30 for cohorts from the last 90 days.
|
||||||
|
-- Useful for analysing early activation (days 1-7) in detail.
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- auth.sessions — Login session records
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS (same pattern as retention_login_weekly)
|
||||||
|
-- cohort_day_start DATE First day the cohort logged in
|
||||||
|
-- cohort_label TEXT Date string (e.g. '2025-03-01')
|
||||||
|
-- cohort_label_n TEXT Date + cohort size (e.g. '2025-03-01 (n=12)')
|
||||||
|
-- user_lifetime_day INT Days since first login (0 = signup day)
|
||||||
|
-- cohort_users BIGINT Total users in cohort
|
||||||
|
-- active_users_bounded BIGINT Users active on exactly day k
|
||||||
|
-- retained_users_unbounded BIGINT Users active any time on/after day k
|
||||||
|
-- retention_rate_bounded FLOAT bounded / cohort_users
|
||||||
|
-- retention_rate_unbounded FLOAT unbounded / cohort_users
|
||||||
|
-- cohort_users_d0 BIGINT cohort_users only at day 0, else 0 (safe to SUM)
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Day-1 retention rate (came back next day)
|
||||||
|
-- SELECT cohort_label, retention_rate_bounded AS d1_retention
|
||||||
|
-- FROM analytics.retention_login_daily
|
||||||
|
-- WHERE user_lifetime_day = 1 ORDER BY cohort_day_start;
|
||||||
|
--
|
||||||
|
-- -- Average retention curve across all cohorts
|
||||||
|
-- SELECT user_lifetime_day,
|
||||||
|
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_d0), 0) AS avg_retention
|
||||||
|
-- FROM analytics.retention_login_daily
|
||||||
|
-- GROUP BY 1 ORDER BY 1;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days')::date AS cohort_start),
|
||||||
|
events AS (
|
||||||
|
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||||
|
DATE_TRUNC('day', s.created_at)::date AS day_start
|
||||||
|
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||||
|
),
|
||||||
|
first_login AS (
|
||||||
|
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||||
|
DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
|
||||||
|
FROM events GROUP BY 1
|
||||||
|
HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
|
||||||
|
),
|
||||||
|
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
|
||||||
|
user_day_age AS (
|
||||||
|
SELECT ad.user_id, fl.cohort_day_start,
|
||||||
|
(ad.day_start - DATE_TRUNC('day', fl.first_login_time)::date)::int AS user_lifetime_day
|
||||||
|
FROM activity_days ad JOIN first_login fl USING (user_id)
|
||||||
|
WHERE ad.day_start >= DATE_TRUNC('day', fl.first_login_time)::date
|
||||||
|
),
|
||||||
|
bounded_counts AS (
|
||||||
|
SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||||
|
FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
|
||||||
|
),
|
||||||
|
last_active AS (
|
||||||
|
SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
|
||||||
|
),
|
||||||
|
unbounded_counts AS (
|
||||||
|
SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
|
||||||
|
FROM last_active la
|
||||||
|
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
|
||||||
|
GROUP BY 1,2
|
||||||
|
),
|
||||||
|
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||||
|
cohort_caps AS (
|
||||||
|
SELECT cs.cohort_day_start, cs.cohort_users,
|
||||||
|
LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
|
||||||
|
FROM cohort_sizes cs
|
||||||
|
),
|
||||||
|
grid AS (
|
||||||
|
SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
|
||||||
|
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
g.cohort_day_start,
|
||||||
|
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
|
||||||
|
TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||||
|
g.user_lifetime_day, g.cohort_users,
|
||||||
|
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||||
|
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||||
|
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||||
|
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||||
|
CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
|
||||||
|
FROM grid g
|
||||||
|
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
|
||||||
|
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
|
||||||
|
ORDER BY g.cohort_day_start, g.user_lifetime_day;
|
||||||
@@ -0,0 +1,96 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.retention_login_onboarded_weekly
|
||||||
|
-- Looker source alias: ds101 | Charts: 2
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- Weekly cohort retention from login sessions, restricted to
|
||||||
|
-- users who "onboarded" — defined as running at least one
|
||||||
|
-- agent within 365 days of their first login.
|
||||||
|
-- Filters out users who signed up but never activated,
|
||||||
|
-- giving a cleaner view of engaged-user retention.
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- auth.sessions — Login session records
|
||||||
|
-- platform.AgentGraphExecution — Used to identify onboarders
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- Same as retention_login_weekly (cohort_week_start, user_lifetime_week,
|
||||||
|
-- retention_rate_bounded, retention_rate_unbounded, etc.)
|
||||||
|
-- Only difference: cohort is filtered to onboarded users only.
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Compare week-4 retention: all users vs onboarded only
|
||||||
|
-- SELECT 'all_users' AS segment, AVG(retention_rate_bounded) AS w4_retention
|
||||||
|
-- FROM analytics.retention_login_weekly WHERE user_lifetime_week = 4
|
||||||
|
-- UNION ALL
|
||||||
|
-- SELECT 'onboarded', AVG(retention_rate_bounded)
|
||||||
|
-- FROM analytics.retention_login_onboarded_weekly WHERE user_lifetime_week = 4;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
WITH params AS (SELECT 12::int AS max_weeks, 365::int AS onboarding_window_days),
|
||||||
|
events AS (
|
||||||
|
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||||
|
DATE_TRUNC('week', s.created_at)::date AS week_start
|
||||||
|
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||||
|
),
|
||||||
|
first_login_all AS (
|
||||||
|
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||||
|
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||||
|
FROM events GROUP BY 1
|
||||||
|
),
|
||||||
|
onboarders AS (
|
||||||
|
SELECT fl.user_id FROM first_login_all fl
|
||||||
|
WHERE EXISTS (
|
||||||
|
SELECT 1 FROM platform."AgentGraphExecution" e
|
||||||
|
WHERE e."userId"::text = fl.user_id
|
||||||
|
AND e."createdAt" >= fl.first_login_time
|
||||||
|
AND e."createdAt" < fl.first_login_time
|
||||||
|
+ make_interval(days => (SELECT onboarding_window_days FROM params))
|
||||||
|
)
|
||||||
|
),
|
||||||
|
first_login AS (SELECT * FROM first_login_all WHERE user_id IN (SELECT user_id FROM onboarders)),
|
||||||
|
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||||
|
user_week_age AS (
|
||||||
|
SELECT aw.user_id, fl.cohort_week_start,
|
||||||
|
((aw.week_start - DATE_TRUNC('week',fl.first_login_time)::date)/7)::int AS user_lifetime_week
|
||||||
|
FROM activity_weeks aw JOIN first_login fl USING (user_id)
|
||||||
|
WHERE aw.week_start >= DATE_TRUNC('week',fl.first_login_time)::date
|
||||||
|
),
|
||||||
|
bounded_counts AS (
|
||||||
|
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||||
|
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||||
|
),
|
||||||
|
last_active AS (
|
||||||
|
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||||
|
),
|
||||||
|
unbounded_counts AS (
|
||||||
|
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||||
|
FROM last_active la
|
||||||
|
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||||
|
GROUP BY 1,2
|
||||||
|
),
|
||||||
|
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||||
|
cohort_caps AS (
|
||||||
|
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||||
|
LEAST((SELECT max_weeks FROM params),
|
||||||
|
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||||
|
FROM cohort_sizes cs
|
||||||
|
),
|
||||||
|
grid AS (
|
||||||
|
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||||
|
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
g.cohort_week_start,
|
||||||
|
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||||
|
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||||
|
g.user_lifetime_week, g.cohort_users,
|
||||||
|
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||||
|
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||||
|
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||||
|
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||||
|
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||||
|
FROM grid g
|
||||||
|
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||||
|
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||||
|
ORDER BY g.cohort_week_start, g.user_lifetime_week;
|
||||||
103
autogpt_platform/analytics/queries/retention_login_weekly.sql
Normal file
103
autogpt_platform/analytics/queries/retention_login_weekly.sql
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.retention_login_weekly
|
||||||
|
-- Looker source alias: ds83 | Charts: 2
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- Weekly cohort retention based on login sessions.
|
||||||
|
-- Users are grouped by the ISO week of their first ever login.
|
||||||
|
-- For each cohort × lifetime-week combination, outputs both:
|
||||||
|
-- - bounded rate: % active in exactly that week
|
||||||
|
-- - unbounded rate: % who were ever active on or after that week
|
||||||
|
-- Weeks are capped to the cohort's actual age (no future data points).
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- auth.sessions — Login session records
|
||||||
|
--
|
||||||
|
-- HOW TO READ THE OUTPUT
|
||||||
|
-- cohort_week_start The Monday of the week users first logged in
|
||||||
|
-- user_lifetime_week 0 = signup week, 1 = one week later, etc.
|
||||||
|
-- retention_rate_bounded = active_users_bounded / cohort_users
|
||||||
|
-- retention_rate_unbounded = retained_users_unbounded / cohort_users
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- cohort_week_start DATE First day of the cohort's signup week
|
||||||
|
-- cohort_label TEXT ISO week label (e.g. '2025-W01')
|
||||||
|
-- cohort_label_n TEXT ISO week label with cohort size (e.g. '2025-W01 (n=42)')
|
||||||
|
-- user_lifetime_week INT Weeks since first login (0 = signup week)
|
||||||
|
-- cohort_users BIGINT Total users in this cohort (denominator)
|
||||||
|
-- active_users_bounded BIGINT Users active in exactly week k
|
||||||
|
-- retained_users_unbounded BIGINT Users active any time on/after week k
|
||||||
|
-- retention_rate_bounded FLOAT bounded active / cohort_users
|
||||||
|
-- retention_rate_unbounded FLOAT unbounded retained / cohort_users
|
||||||
|
-- cohort_users_w0 BIGINT cohort_users only at week 0, else 0 (safe to SUM in pivot tables)
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Week-1 retention rate per cohort
|
||||||
|
-- SELECT cohort_label, retention_rate_bounded AS w1_retention
|
||||||
|
-- FROM analytics.retention_login_weekly
|
||||||
|
-- WHERE user_lifetime_week = 1
|
||||||
|
-- ORDER BY cohort_week_start;
|
||||||
|
--
|
||||||
|
-- -- Overall average retention curve (all cohorts combined)
|
||||||
|
-- SELECT user_lifetime_week,
|
||||||
|
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_w0), 0) AS avg_retention
|
||||||
|
-- FROM analytics.retention_login_weekly
|
||||||
|
-- GROUP BY 1 ORDER BY 1;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
WITH params AS (SELECT 12::int AS max_weeks),
|
||||||
|
events AS (
|
||||||
|
SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
|
||||||
|
DATE_TRUNC('week', s.created_at)::date AS week_start
|
||||||
|
FROM auth.sessions s WHERE s.user_id IS NOT NULL
|
||||||
|
),
|
||||||
|
first_login AS (
|
||||||
|
SELECT user_id, MIN(created_at) AS first_login_time,
|
||||||
|
DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
|
||||||
|
FROM events GROUP BY 1
|
||||||
|
),
|
||||||
|
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
|
||||||
|
user_week_age AS (
|
||||||
|
SELECT aw.user_id, fl.cohort_week_start,
|
||||||
|
((aw.week_start - DATE_TRUNC('week', fl.first_login_time)::date) / 7)::int AS user_lifetime_week
|
||||||
|
FROM activity_weeks aw JOIN first_login fl USING (user_id)
|
||||||
|
WHERE aw.week_start >= DATE_TRUNC('week', fl.first_login_time)::date
|
||||||
|
),
|
||||||
|
bounded_counts AS (
|
||||||
|
SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
|
||||||
|
FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
|
||||||
|
),
|
||||||
|
last_active AS (
|
||||||
|
SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
|
||||||
|
),
|
||||||
|
unbounded_counts AS (
|
||||||
|
SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
||||||
|
FROM last_active la
|
||||||
|
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
||||||
|
GROUP BY 1,2
|
||||||
|
),
|
||||||
|
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
||||||
|
cohort_caps AS (
|
||||||
|
SELECT cs.cohort_week_start, cs.cohort_users,
|
||||||
|
LEAST((SELECT max_weeks FROM params),
|
||||||
|
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date - cs.cohort_week_start)/7)::int)) AS cap_weeks
|
||||||
|
FROM cohort_sizes cs
|
||||||
|
),
|
||||||
|
grid AS (
|
||||||
|
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
||||||
|
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
g.cohort_week_start,
|
||||||
|
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
||||||
|
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
||||||
|
g.user_lifetime_week, g.cohort_users,
|
||||||
|
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
||||||
|
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
||||||
|
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
||||||
|
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
||||||
|
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
||||||
|
FROM grid g
|
||||||
|
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
||||||
|
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
||||||
|
ORDER BY g.cohort_week_start, g.user_lifetime_week
|
||||||
71
autogpt_platform/analytics/queries/user_block_spending.sql
Normal file
71
autogpt_platform/analytics/queries/user_block_spending.sql
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.user_block_spending
|
||||||
|
-- Looker source alias: ds6 | Charts: 5
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- One row per credit transaction (last 90 days).
|
||||||
|
-- Shows how users spend credits broken down by block type,
|
||||||
|
-- LLM provider and model. Joins node execution stats for
|
||||||
|
-- token-level detail.
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- platform.CreditTransaction — Credit debit/credit records
|
||||||
|
-- platform.AgentNodeExecution — Node execution stats (for token counts)
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- transactionKey TEXT Unique transaction identifier
|
||||||
|
-- userId TEXT User who was charged
|
||||||
|
-- amount DECIMAL Credit amount (positive = credit, negative = debit)
|
||||||
|
-- negativeAmount DECIMAL amount * -1 (convenience for spend charts)
|
||||||
|
-- transactionType TEXT Transaction type (e.g. 'USAGE', 'REFUND', 'TOP_UP')
|
||||||
|
-- transactionTime TIMESTAMPTZ When the transaction was recorded
|
||||||
|
-- blockId TEXT Block UUID that triggered the spend
|
||||||
|
-- blockName TEXT Human-readable block name
|
||||||
|
-- llm_provider TEXT LLM provider (e.g. 'openai', 'anthropic')
|
||||||
|
-- llm_model TEXT Model name (e.g. 'gpt-4o', 'claude-3-5-sonnet')
|
||||||
|
-- node_exec_id TEXT Linked node execution UUID
|
||||||
|
-- llm_call_count INT LLM API calls made in that execution
|
||||||
|
-- llm_retry_count INT LLM retries in that execution
|
||||||
|
-- llm_input_token_count INT Input tokens consumed
|
||||||
|
-- llm_output_token_count INT Output tokens produced
|
||||||
|
--
|
||||||
|
-- WINDOW
|
||||||
|
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Total spend per user (last 90 days)
|
||||||
|
-- SELECT "userId", SUM("negativeAmount") AS total_spent
|
||||||
|
-- FROM analytics.user_block_spending
|
||||||
|
-- WHERE "transactionType" = 'USAGE'
|
||||||
|
-- GROUP BY 1 ORDER BY total_spent DESC;
|
||||||
|
--
|
||||||
|
-- -- Spend by LLM provider + model
|
||||||
|
-- SELECT "llm_provider", "llm_model",
|
||||||
|
-- SUM("negativeAmount") AS total_cost,
|
||||||
|
-- SUM("llm_input_token_count") AS input_tokens,
|
||||||
|
-- SUM("llm_output_token_count") AS output_tokens
|
||||||
|
-- FROM analytics.user_block_spending
|
||||||
|
-- WHERE "llm_provider" IS NOT NULL
|
||||||
|
-- GROUP BY 1, 2 ORDER BY total_cost DESC;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
SELECT
|
||||||
|
c."transactionKey" AS transactionKey,
|
||||||
|
c."userId" AS userId,
|
||||||
|
c."amount" AS amount,
|
||||||
|
c."amount" * -1 AS negativeAmount,
|
||||||
|
c."type" AS transactionType,
|
||||||
|
c."createdAt" AS transactionTime,
|
||||||
|
c.metadata->>'block_id' AS blockId,
|
||||||
|
c.metadata->>'block' AS blockName,
|
||||||
|
c.metadata->'input'->'credentials'->>'provider' AS llm_provider,
|
||||||
|
c.metadata->'input'->>'model' AS llm_model,
|
||||||
|
c.metadata->>'node_exec_id' AS node_exec_id,
|
||||||
|
(ne."stats"->>'llm_call_count')::int AS llm_call_count,
|
||||||
|
(ne."stats"->>'llm_retry_count')::int AS llm_retry_count,
|
||||||
|
(ne."stats"->>'input_token_count')::int AS llm_input_token_count,
|
||||||
|
(ne."stats"->>'output_token_count')::int AS llm_output_token_count
|
||||||
|
FROM platform."CreditTransaction" c
|
||||||
|
LEFT JOIN platform."AgentNodeExecution" ne
|
||||||
|
ON (c.metadata->>'node_exec_id') = ne."id"::text
|
||||||
|
WHERE c."createdAt" > CURRENT_DATE - INTERVAL '90 days'
|
||||||
45
autogpt_platform/analytics/queries/user_onboarding.sql
Normal file
45
autogpt_platform/analytics/queries/user_onboarding.sql
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.user_onboarding
|
||||||
|
-- Looker source alias: ds68 | Charts: 3
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- One row per user onboarding record. Contains the user's
|
||||||
|
-- stated usage reason, selected integrations, completed
|
||||||
|
-- onboarding steps and optional first agent selection.
|
||||||
|
-- Full history (no date filter) since onboarding happens
|
||||||
|
-- once per user.
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- platform.UserOnboarding — Onboarding state per user
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- id TEXT Onboarding record UUID
|
||||||
|
-- createdAt TIMESTAMPTZ When onboarding started
|
||||||
|
-- updatedAt TIMESTAMPTZ Last update to onboarding state
|
||||||
|
-- usageReason TEXT Why user signed up (e.g. 'work', 'personal')
|
||||||
|
-- integrations TEXT[] Array of integration names the user selected
|
||||||
|
-- userId TEXT User UUID
|
||||||
|
-- completedSteps TEXT[] Array of onboarding step enums completed
|
||||||
|
-- selectedStoreListingVersionId TEXT First marketplace agent the user chose (if any)
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Usage reason breakdown
|
||||||
|
-- SELECT "usageReason", COUNT(*) FROM analytics.user_onboarding GROUP BY 1;
|
||||||
|
--
|
||||||
|
-- -- Completion rate per step
|
||||||
|
-- SELECT step, COUNT(*) AS users_completed
|
||||||
|
-- FROM analytics.user_onboarding
|
||||||
|
-- CROSS JOIN LATERAL UNNEST("completedSteps") AS step
|
||||||
|
-- GROUP BY 1 ORDER BY users_completed DESC;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
"createdAt",
|
||||||
|
"updatedAt",
|
||||||
|
"usageReason",
|
||||||
|
integrations,
|
||||||
|
"userId",
|
||||||
|
"completedSteps",
|
||||||
|
"selectedStoreListingVersionId"
|
||||||
|
FROM platform."UserOnboarding"
|
||||||
100
autogpt_platform/analytics/queries/user_onboarding_funnel.sql
Normal file
100
autogpt_platform/analytics/queries/user_onboarding_funnel.sql
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.user_onboarding_funnel
|
||||||
|
-- Looker source alias: ds74 | Charts: 1
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- Pre-aggregated onboarding funnel showing how many users
|
||||||
|
-- completed each step and the drop-off percentage from the
|
||||||
|
-- previous step. One row per onboarding step (all 22 steps
|
||||||
|
-- always present, even with 0 completions — prevents sparse
|
||||||
|
-- gaps from making LAG compare the wrong predecessors).
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- platform.UserOnboarding — Onboarding records with completedSteps array
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- step TEXT Onboarding step enum name (e.g. 'WELCOME', 'CONGRATS')
|
||||||
|
-- step_order INT Numeric position in the funnel (1=first, 22=last)
|
||||||
|
-- users_completed BIGINT Distinct users who completed this step
|
||||||
|
-- pct_from_prev NUMERIC % of users from the previous step who reached this one
|
||||||
|
--
|
||||||
|
-- STEP ORDER
|
||||||
|
-- 1 WELCOME 9 MARKETPLACE_VISIT 17 SCHEDULE_AGENT
|
||||||
|
-- 2 USAGE_REASON 10 MARKETPLACE_ADD_AGENT 18 RUN_AGENTS
|
||||||
|
-- 3 INTEGRATIONS 11 MARKETPLACE_RUN_AGENT 19 RUN_3_DAYS
|
||||||
|
-- 4 AGENT_CHOICE 12 BUILDER_OPEN 20 TRIGGER_WEBHOOK
|
||||||
|
-- 5 AGENT_NEW_RUN 13 BUILDER_SAVE_AGENT 21 RUN_14_DAYS
|
||||||
|
-- 6 AGENT_INPUT 14 BUILDER_RUN_AGENT 22 RUN_AGENTS_100
|
||||||
|
-- 7 CONGRATS 15 VISIT_COPILOT
|
||||||
|
-- 8 GET_RESULTS 16 RE_RUN_AGENT
|
||||||
|
--
|
||||||
|
-- WINDOW
|
||||||
|
-- Users who started onboarding in the last 90 days
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Full funnel
|
||||||
|
-- SELECT * FROM analytics.user_onboarding_funnel ORDER BY step_order;
|
||||||
|
--
|
||||||
|
-- -- Biggest drop-off point
|
||||||
|
-- SELECT step, pct_from_prev FROM analytics.user_onboarding_funnel
|
||||||
|
-- ORDER BY pct_from_prev ASC LIMIT 3;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
WITH all_steps AS (
|
||||||
|
-- Complete ordered grid of all 22 steps so zero-completion steps
|
||||||
|
-- are always present, keeping LAG comparisons correct.
|
||||||
|
SELECT step_name, step_order
|
||||||
|
FROM (VALUES
|
||||||
|
('WELCOME', 1),
|
||||||
|
('USAGE_REASON', 2),
|
||||||
|
('INTEGRATIONS', 3),
|
||||||
|
('AGENT_CHOICE', 4),
|
||||||
|
('AGENT_NEW_RUN', 5),
|
||||||
|
('AGENT_INPUT', 6),
|
||||||
|
('CONGRATS', 7),
|
||||||
|
('GET_RESULTS', 8),
|
||||||
|
('MARKETPLACE_VISIT', 9),
|
||||||
|
('MARKETPLACE_ADD_AGENT', 10),
|
||||||
|
('MARKETPLACE_RUN_AGENT', 11),
|
||||||
|
('BUILDER_OPEN', 12),
|
||||||
|
('BUILDER_SAVE_AGENT', 13),
|
||||||
|
('BUILDER_RUN_AGENT', 14),
|
||||||
|
('VISIT_COPILOT', 15),
|
||||||
|
('RE_RUN_AGENT', 16),
|
||||||
|
('SCHEDULE_AGENT', 17),
|
||||||
|
('RUN_AGENTS', 18),
|
||||||
|
('RUN_3_DAYS', 19),
|
||||||
|
('TRIGGER_WEBHOOK', 20),
|
||||||
|
('RUN_14_DAYS', 21),
|
||||||
|
('RUN_AGENTS_100', 22)
|
||||||
|
) AS t(step_name, step_order)
|
||||||
|
),
|
||||||
|
raw AS (
|
||||||
|
SELECT
|
||||||
|
u."userId",
|
||||||
|
step_txt::text AS step
|
||||||
|
FROM platform."UserOnboarding" u
|
||||||
|
CROSS JOIN LATERAL UNNEST(u."completedSteps") AS step_txt
|
||||||
|
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
|
||||||
|
),
|
||||||
|
step_counts AS (
|
||||||
|
SELECT step, COUNT(DISTINCT "userId") AS users_completed
|
||||||
|
FROM raw GROUP BY step
|
||||||
|
),
|
||||||
|
funnel AS (
|
||||||
|
SELECT
|
||||||
|
a.step_name AS step,
|
||||||
|
a.step_order,
|
||||||
|
COALESCE(sc.users_completed, 0) AS users_completed,
|
||||||
|
ROUND(
|
||||||
|
100.0 * COALESCE(sc.users_completed, 0)
|
||||||
|
/ NULLIF(
|
||||||
|
LAG(COALESCE(sc.users_completed, 0)) OVER (ORDER BY a.step_order),
|
||||||
|
0
|
||||||
|
),
|
||||||
|
2
|
||||||
|
) AS pct_from_prev
|
||||||
|
FROM all_steps a
|
||||||
|
LEFT JOIN step_counts sc ON sc.step = a.step_name
|
||||||
|
)
|
||||||
|
SELECT * FROM funnel ORDER BY step_order
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.user_onboarding_integration
|
||||||
|
-- Looker source alias: ds75 | Charts: 1
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- Pre-aggregated count of users who selected each integration
|
||||||
|
-- during onboarding. One row per integration type, sorted
|
||||||
|
-- by popularity.
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- platform.UserOnboarding — integrations array column
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- integration TEXT Integration name (e.g. 'github', 'slack', 'notion')
|
||||||
|
-- users_with_integration BIGINT Distinct users who selected this integration
|
||||||
|
--
|
||||||
|
-- WINDOW
|
||||||
|
-- Users who started onboarding in the last 90 days
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Full integration popularity ranking
|
||||||
|
-- SELECT * FROM analytics.user_onboarding_integration;
|
||||||
|
--
|
||||||
|
-- -- Top 5 integrations
|
||||||
|
-- SELECT * FROM analytics.user_onboarding_integration LIMIT 5;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
WITH exploded AS (
|
||||||
|
SELECT
|
||||||
|
u."userId" AS user_id,
|
||||||
|
UNNEST(u."integrations") AS integration
|
||||||
|
FROM platform."UserOnboarding" u
|
||||||
|
WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
integration,
|
||||||
|
COUNT(DISTINCT user_id) AS users_with_integration
|
||||||
|
FROM exploded
|
||||||
|
WHERE integration IS NOT NULL AND integration <> ''
|
||||||
|
GROUP BY integration
|
||||||
|
ORDER BY users_with_integration DESC
|
||||||
145
autogpt_platform/analytics/queries/users_activities.sql
Normal file
145
autogpt_platform/analytics/queries/users_activities.sql
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
-- =============================================================
|
||||||
|
-- View: analytics.users_activities
|
||||||
|
-- Looker source alias: ds56 | Charts: 5
|
||||||
|
-- =============================================================
|
||||||
|
-- DESCRIPTION
|
||||||
|
-- One row per user with lifetime activity summary.
|
||||||
|
-- Joins login sessions with agent graphs, executions and
|
||||||
|
-- node-level runs to give a full picture of how engaged
|
||||||
|
-- each user is. Includes a convenience flag for 7-day
|
||||||
|
-- activation (did the user return at least 7 days after
|
||||||
|
-- their first login?).
|
||||||
|
--
|
||||||
|
-- SOURCE TABLES
|
||||||
|
-- auth.sessions — Login/session records
|
||||||
|
-- platform.AgentGraph — Graphs (agents) built by the user
|
||||||
|
-- platform.AgentGraphExecution — Agent run history
|
||||||
|
-- platform.AgentNodeExecution — Individual block execution history
|
||||||
|
--
|
||||||
|
-- PERFORMANCE NOTE
|
||||||
|
-- Each CTE aggregates its own table independently by userId.
|
||||||
|
-- This avoids the fan-out that occurs when driving every join
|
||||||
|
-- from user_logins across the two largest tables
|
||||||
|
-- (AgentGraphExecution and AgentNodeExecution).
|
||||||
|
--
|
||||||
|
-- OUTPUT COLUMNS
|
||||||
|
-- user_id TEXT Supabase user UUID
|
||||||
|
-- first_login_time TIMESTAMPTZ First ever session created_at
|
||||||
|
-- last_login_time TIMESTAMPTZ Most recent session created_at
|
||||||
|
-- last_visit_time TIMESTAMPTZ Max of last refresh or login
|
||||||
|
-- last_agent_save_time TIMESTAMPTZ Last time user saved an agent graph
|
||||||
|
-- agent_count BIGINT Number of distinct active graphs built (0 if none)
|
||||||
|
-- first_agent_run_time TIMESTAMPTZ First ever graph execution
|
||||||
|
-- last_agent_run_time TIMESTAMPTZ Most recent graph execution
|
||||||
|
-- unique_agent_runs BIGINT Distinct agent graphs ever run (0 if none)
|
||||||
|
-- agent_runs BIGINT Total graph execution count (0 if none)
|
||||||
|
-- node_execution_count BIGINT Total node executions across all runs
|
||||||
|
-- node_execution_failed BIGINT Node executions with FAILED status
|
||||||
|
-- node_execution_completed BIGINT Node executions with COMPLETED status
|
||||||
|
-- node_execution_terminated BIGINT Node executions with TERMINATED status
|
||||||
|
-- node_execution_queued BIGINT Node executions with QUEUED status
|
||||||
|
-- node_execution_running BIGINT Node executions with RUNNING status
|
||||||
|
-- is_active_after_7d INT 1=returned after day 7, 0=did not, NULL=too early to tell
|
||||||
|
-- node_execution_incomplete BIGINT Node executions with INCOMPLETE status
|
||||||
|
-- node_execution_review BIGINT Node executions with REVIEW status
|
||||||
|
--
|
||||||
|
-- EXAMPLE QUERIES
|
||||||
|
-- -- Users who ran at least one agent and returned after 7 days
|
||||||
|
-- SELECT COUNT(*) FROM analytics.users_activities
|
||||||
|
-- WHERE agent_runs > 0 AND is_active_after_7d = 1;
|
||||||
|
--
|
||||||
|
-- -- Top 10 most active users by agent runs
|
||||||
|
-- SELECT user_id, agent_runs, node_execution_count
|
||||||
|
-- FROM analytics.users_activities
|
||||||
|
-- ORDER BY agent_runs DESC LIMIT 10;
|
||||||
|
--
|
||||||
|
-- -- 7-day activation rate
|
||||||
|
-- SELECT
|
||||||
|
-- SUM(CASE WHEN is_active_after_7d = 1 THEN 1 ELSE 0 END)::float
|
||||||
|
-- / NULLIF(COUNT(CASE WHEN is_active_after_7d IS NOT NULL THEN 1 END), 0)
|
||||||
|
-- AS activation_rate
|
||||||
|
-- FROM analytics.users_activities;
|
||||||
|
-- =============================================================
|
||||||
|
|
||||||
|
WITH user_logins AS (
|
||||||
|
SELECT
|
||||||
|
user_id::text AS user_id,
|
||||||
|
MIN(created_at) AS first_login_time,
|
||||||
|
MAX(created_at) AS last_login_time,
|
||||||
|
GREATEST(
|
||||||
|
MAX(refreshed_at)::timestamptz,
|
||||||
|
MAX(created_at)::timestamptz
|
||||||
|
) AS last_visit_time
|
||||||
|
FROM auth.sessions
|
||||||
|
GROUP BY user_id
|
||||||
|
),
|
||||||
|
user_agents AS (
|
||||||
|
-- Aggregate AgentGraph directly by userId (no fan-out from user_logins)
|
||||||
|
SELECT
|
||||||
|
"userId"::text AS user_id,
|
||||||
|
MAX("updatedAt") AS last_agent_save_time,
|
||||||
|
COUNT(DISTINCT "id") AS agent_count
|
||||||
|
FROM platform."AgentGraph"
|
||||||
|
WHERE "isActive"
|
||||||
|
GROUP BY "userId"
|
||||||
|
),
|
||||||
|
user_graph_runs AS (
|
||||||
|
-- Aggregate AgentGraphExecution directly by userId
|
||||||
|
SELECT
|
||||||
|
"userId"::text AS user_id,
|
||||||
|
MIN("createdAt") AS first_agent_run_time,
|
||||||
|
MAX("createdAt") AS last_agent_run_time,
|
||||||
|
COUNT(DISTINCT "agentGraphId") AS unique_agent_runs,
|
||||||
|
COUNT("id") AS agent_runs
|
||||||
|
FROM platform."AgentGraphExecution"
|
||||||
|
GROUP BY "userId"
|
||||||
|
),
|
||||||
|
user_node_runs AS (
|
||||||
|
-- Aggregate AgentNodeExecution directly; resolve userId via a
|
||||||
|
-- single join to AgentGraphExecution instead of fanning out from
|
||||||
|
-- user_logins through both large tables.
|
||||||
|
SELECT
|
||||||
|
g."userId"::text AS user_id,
|
||||||
|
COUNT(*) AS node_execution_count,
|
||||||
|
COUNT(*) FILTER (WHERE n."executionStatus" = 'FAILED') AS node_execution_failed,
|
||||||
|
COUNT(*) FILTER (WHERE n."executionStatus" = 'COMPLETED') AS node_execution_completed,
|
||||||
|
COUNT(*) FILTER (WHERE n."executionStatus" = 'TERMINATED') AS node_execution_terminated,
|
||||||
|
COUNT(*) FILTER (WHERE n."executionStatus" = 'QUEUED') AS node_execution_queued,
|
||||||
|
COUNT(*) FILTER (WHERE n."executionStatus" = 'RUNNING') AS node_execution_running,
|
||||||
|
COUNT(*) FILTER (WHERE n."executionStatus" = 'INCOMPLETE') AS node_execution_incomplete,
|
||||||
|
COUNT(*) FILTER (WHERE n."executionStatus" = 'REVIEW') AS node_execution_review
|
||||||
|
FROM platform."AgentNodeExecution" n
|
||||||
|
JOIN platform."AgentGraphExecution" g
|
||||||
|
ON g."id" = n."agentGraphExecutionId"
|
||||||
|
GROUP BY g."userId"
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
ul.user_id,
|
||||||
|
ul.first_login_time,
|
||||||
|
ul.last_login_time,
|
||||||
|
ul.last_visit_time,
|
||||||
|
ua.last_agent_save_time,
|
||||||
|
COALESCE(ua.agent_count, 0) AS agent_count,
|
||||||
|
gr.first_agent_run_time,
|
||||||
|
gr.last_agent_run_time,
|
||||||
|
COALESCE(gr.unique_agent_runs, 0) AS unique_agent_runs,
|
||||||
|
COALESCE(gr.agent_runs, 0) AS agent_runs,
|
||||||
|
COALESCE(nr.node_execution_count, 0) AS node_execution_count,
|
||||||
|
COALESCE(nr.node_execution_failed, 0) AS node_execution_failed,
|
||||||
|
COALESCE(nr.node_execution_completed, 0) AS node_execution_completed,
|
||||||
|
COALESCE(nr.node_execution_terminated, 0) AS node_execution_terminated,
|
||||||
|
COALESCE(nr.node_execution_queued, 0) AS node_execution_queued,
|
||||||
|
COALESCE(nr.node_execution_running, 0) AS node_execution_running,
|
||||||
|
CASE
|
||||||
|
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
|
||||||
|
AND ul.last_visit_time >= ul.first_login_time + INTERVAL '7 days' THEN 1
|
||||||
|
WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
|
||||||
|
AND ul.last_visit_time < ul.first_login_time + INTERVAL '7 days' THEN 0
|
||||||
|
ELSE NULL
|
||||||
|
END AS is_active_after_7d,
|
||||||
|
COALESCE(nr.node_execution_incomplete, 0) AS node_execution_incomplete,
|
||||||
|
COALESCE(nr.node_execution_review, 0) AS node_execution_review
|
||||||
|
FROM user_logins ul
|
||||||
|
LEFT JOIN user_agents ua ON ul.user_id = ua.user_id
|
||||||
|
LEFT JOIN user_graph_runs gr ON ul.user_id = gr.user_id
|
||||||
|
LEFT JOIN user_node_runs nr ON ul.user_id = nr.user_id
|
||||||
169
autogpt_platform/autogpt_libs/poetry.lock
generated
169
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -448,61 +448,61 @@ toml = ["tomli ; python_full_version <= \"3.11.0a6\""]
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cryptography"
|
name = "cryptography"
|
||||||
version = "46.0.4"
|
version = "46.0.5"
|
||||||
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
|
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = "!=3.9.0,!=3.9.1,>=3.8"
|
python-versions = "!=3.9.0,!=3.9.1,>=3.8"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "cryptography-46.0.4-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:281526e865ed4166009e235afadf3a4c4cba6056f99336a99efba65336fd5485"},
|
{file = "cryptography-46.0.5-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:351695ada9ea9618b3500b490ad54c739860883df6c1f555e088eaf25b1bbaad"},
|
||||||
{file = "cryptography-46.0.4-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5f14fba5bf6f4390d7ff8f086c566454bff0411f6d8aa7af79c88b6f9267aecc"},
|
{file = "cryptography-46.0.5-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c18ff11e86df2e28854939acde2d003f7984f721eba450b56a200ad90eeb0e6b"},
|
||||||
{file = "cryptography-46.0.4-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:47bcd19517e6389132f76e2d5303ded6cf3f78903da2158a671be8de024f4cd0"},
|
{file = "cryptography-46.0.5-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d7e3d356b8cd4ea5aff04f129d5f66ebdc7b6f8eae802b93739ed520c47c79b"},
|
||||||
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:01df4f50f314fbe7009f54046e908d1754f19d0c6d3070df1e6268c5a4af09fa"},
|
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:50bfb6925eff619c9c023b967d5b77a54e04256c4281b0e21336a130cd7fc263"},
|
||||||
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:5aa3e463596b0087b3da0dbe2b2487e9fc261d25da85754e30e3b40637d61f81"},
|
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:803812e111e75d1aa73690d2facc295eaefd4439be1023fefc4995eaea2af90d"},
|
||||||
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:0a9ad24359fee86f131836a9ac3bffc9329e956624a2d379b613f8f8abaf5255"},
|
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ee190460e2fbe447175cda91b88b84ae8322a104fc27766ad09428754a618ed"},
|
||||||
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:dc1272e25ef673efe72f2096e92ae39dea1a1a450dd44918b15351f72c5a168e"},
|
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:f145bba11b878005c496e93e257c1e88f154d278d2638e6450d17e0f31e558d2"},
|
||||||
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:de0f5f4ec8711ebc555f54735d4c673fc34b65c44283895f1a08c2b49d2fd99c"},
|
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e9251e3be159d1020c4030bd2e5f84d6a43fe54b6c19c12f51cde9542a2817b2"},
|
||||||
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:eeeb2e33d8dbcccc34d64651f00a98cb41b2dc69cef866771a5717e6734dfa32"},
|
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:47fb8a66058b80e509c47118ef8a75d14c455e81ac369050f20ba0d23e77fee0"},
|
||||||
{file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:3d425eacbc9aceafd2cb429e42f4e5d5633c6f873f5e567077043ef1b9bbf616"},
|
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:4c3341037c136030cb46e4b1e17b7418ea4cbd9dd207e4a6f3b2b24e0d4ac731"},
|
||||||
{file = "cryptography-46.0.4-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:91627ebf691d1ea3976a031b61fb7bac1ccd745afa03602275dda443e11c8de0"},
|
{file = "cryptography-46.0.5-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:890bcb4abd5a2d3f852196437129eb3667d62630333aacc13dfd470fad3aaa82"},
|
||||||
{file = "cryptography-46.0.4-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:2d08bc22efd73e8854b0b7caff402d735b354862f1145d7be3b9c0f740fef6a0"},
|
{file = "cryptography-46.0.5-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:80a8d7bfdf38f87ca30a5391c0c9ce4ed2926918e017c29ddf643d0ed2778ea1"},
|
||||||
{file = "cryptography-46.0.4-cp311-abi3-win32.whl", hash = "sha256:82a62483daf20b8134f6e92898da70d04d0ef9a75829d732ea1018678185f4f5"},
|
{file = "cryptography-46.0.5-cp311-abi3-win32.whl", hash = "sha256:60ee7e19e95104d4c03871d7d7dfb3d22ef8a9b9c6778c94e1c8fcc8365afd48"},
|
||||||
{file = "cryptography-46.0.4-cp311-abi3-win_amd64.whl", hash = "sha256:6225d3ebe26a55dbc8ead5ad1265c0403552a63336499564675b29eb3184c09b"},
|
{file = "cryptography-46.0.5-cp311-abi3-win_amd64.whl", hash = "sha256:38946c54b16c885c72c4f59846be9743d699eee2b69b6988e0a00a01f46a61a4"},
|
||||||
{file = "cryptography-46.0.4-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:485e2b65d25ec0d901bca7bcae0f53b00133bf3173916d8e421f6fddde103908"},
|
{file = "cryptography-46.0.5-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:94a76daa32eb78d61339aff7952ea819b1734b46f73646a07decb40e5b3448e2"},
|
||||||
{file = "cryptography-46.0.4-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:078e5f06bd2fa5aea5a324f2a09f914b1484f1d0c2a4d6a8a28c74e72f65f2da"},
|
{file = "cryptography-46.0.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5be7bf2fb40769e05739dd0046e7b26f9d4670badc7b032d6ce4db64dddc0678"},
|
||||||
{file = "cryptography-46.0.4-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dce1e4f068f03008da7fa51cc7abc6ddc5e5de3e3d1550334eaf8393982a5829"},
|
{file = "cryptography-46.0.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe346b143ff9685e40192a4960938545c699054ba11d4f9029f94751e3f71d87"},
|
||||||
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:2067461c80271f422ee7bdbe79b9b4be54a5162e90345f86a23445a0cf3fd8a2"},
|
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:c69fd885df7d089548a42d5ec05be26050ebcd2283d89b3d30676eb32ff87dee"},
|
||||||
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:c92010b58a51196a5f41c3795190203ac52edfd5dc3ff99149b4659eba9d2085"},
|
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:8293f3dea7fc929ef7240796ba231413afa7b68ce38fd21da2995549f5961981"},
|
||||||
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:829c2b12bbc5428ab02d6b7f7e9bbfd53e33efd6672d21341f2177470171ad8b"},
|
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:1abfdb89b41c3be0365328a410baa9df3ff8a9110fb75e7b52e66803ddabc9a9"},
|
||||||
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:62217ba44bf81b30abaeda1488686a04a702a261e26f87db51ff61d9d3510abd"},
|
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:d66e421495fdb797610a08f43b05269e0a5ea7f5e652a89bfd5a7d3c1dee3648"},
|
||||||
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:9c2da296c8d3415b93e6053f5a728649a87a48ce084a9aaf51d6e46c87c7f2d2"},
|
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:4e817a8920bfbcff8940ecfd60f23d01836408242b30f1a708d93198393a80b4"},
|
||||||
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:9b34d8ba84454641a6bf4d6762d15847ecbd85c1316c0a7984e6e4e9f748ec2e"},
|
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:68f68d13f2e1cb95163fa3b4db4bf9a159a418f5f6e7242564fc75fcae667fd0"},
|
||||||
{file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:df4a817fa7138dd0c96c8c8c20f04b8aaa1fac3bbf610913dcad8ea82e1bfd3f"},
|
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:a3d1fae9863299076f05cb8a778c467578262fae09f9dc0ee9b12eb4268ce663"},
|
||||||
{file = "cryptography-46.0.4-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:b1de0ebf7587f28f9190b9cb526e901bf448c9e6a99655d2b07fff60e8212a82"},
|
{file = "cryptography-46.0.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:c4143987a42a2397f2fc3b4d7e3a7d313fbe684f67ff443999e803dd75a76826"},
|
||||||
{file = "cryptography-46.0.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:9b4d17bc7bd7cdd98e3af40b441feaea4c68225e2eb2341026c84511ad246c0c"},
|
{file = "cryptography-46.0.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:7d731d4b107030987fd61a7f8ab512b25b53cef8f233a97379ede116f30eb67d"},
|
||||||
{file = "cryptography-46.0.4-cp314-cp314t-win32.whl", hash = "sha256:c411f16275b0dea722d76544a61d6421e2cc829ad76eec79280dbdc9ddf50061"},
|
{file = "cryptography-46.0.5-cp314-cp314t-win32.whl", hash = "sha256:c3bcce8521d785d510b2aad26ae2c966092b7daa8f45dd8f44734a104dc0bc1a"},
|
||||||
{file = "cryptography-46.0.4-cp314-cp314t-win_amd64.whl", hash = "sha256:728fedc529efc1439eb6107b677f7f7558adab4553ef8669f0d02d42d7b959a7"},
|
{file = "cryptography-46.0.5-cp314-cp314t-win_amd64.whl", hash = "sha256:4d8ae8659ab18c65ced284993c2265910f6c9e650189d4e3f68445ef82a810e4"},
|
||||||
{file = "cryptography-46.0.4-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a9556ba711f7c23f77b151d5798f3ac44a13455cc68db7697a1096e6d0563cab"},
|
{file = "cryptography-46.0.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:4108d4c09fbbf2789d0c926eb4152ae1760d5a2d97612b92d508d96c861e4d31"},
|
||||||
{file = "cryptography-46.0.4-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8bf75b0259e87fa70bddc0b8b4078b76e7fd512fd9afae6c1193bcf440a4dbef"},
|
{file = "cryptography-46.0.5-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d1f30a86d2757199cb2d56e48cce14deddf1f9c95f1ef1b64ee91ea43fe2e18"},
|
||||||
{file = "cryptography-46.0.4-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3c268a3490df22270955966ba236d6bc4a8f9b6e4ffddb78aac535f1a5ea471d"},
|
{file = "cryptography-46.0.5-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:039917b0dc418bb9f6edce8a906572d69e74bd330b0b3fea4f79dab7f8ddd235"},
|
||||||
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:812815182f6a0c1d49a37893a303b44eaac827d7f0d582cecfc81b6427f22973"},
|
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ba2a27ff02f48193fc4daeadf8ad2590516fa3d0adeeb34336b96f7fa64c1e3a"},
|
||||||
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:a90e43e3ef65e6dcf969dfe3bb40cbf5aef0d523dff95bfa24256be172a845f4"},
|
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:61aa400dce22cb001a98014f647dc21cda08f7915ceb95df0c9eaf84b4b6af76"},
|
||||||
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:a05177ff6296644ef2876fce50518dffb5bcdf903c85250974fc8bc85d54c0af"},
|
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ce58ba46e1bc2aac4f7d9290223cead56743fa6ab94a5d53292ffaac6a91614"},
|
||||||
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:daa392191f626d50f1b136c9b4cf08af69ca8279d110ea24f5c2700054d2e263"},
|
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:420d0e909050490d04359e7fdb5ed7e667ca5c3c402b809ae2563d7e66a92229"},
|
||||||
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e07ea39c5b048e085f15923511d8121e4a9dc45cee4e3b970ca4f0d338f23095"},
|
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:582f5fcd2afa31622f317f80426a027f30dc792e9c80ffee87b993200ea115f1"},
|
||||||
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:d5a45ddc256f492ce42a4e35879c5e5528c09cd9ad12420828c972951d8e016b"},
|
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:bfd56bb4b37ed4f330b82402f6f435845a5f5648edf1ad497da51a8452d5d62d"},
|
||||||
{file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:6bb5157bf6a350e5b28aee23beb2d84ae6f5be390b2f8ee7ea179cda077e1019"},
|
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:a3d507bb6a513ca96ba84443226af944b0f7f47dcc9a399d110cd6146481d24c"},
|
||||||
{file = "cryptography-46.0.4-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:dd5aba870a2c40f87a3af043e0dee7d9eb02d4aff88a797b48f2b43eff8c3ab4"},
|
{file = "cryptography-46.0.5-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9f16fbdf4da055efb21c22d81b89f155f02ba420558db21288b3d0035bafd5f4"},
|
||||||
{file = "cryptography-46.0.4-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:93d8291da8d71024379ab2cb0b5c57915300155ad42e07f76bea6ad838d7e59b"},
|
{file = "cryptography-46.0.5-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:ced80795227d70549a411a4ab66e8ce307899fad2220ce5ab2f296e687eacde9"},
|
||||||
{file = "cryptography-46.0.4-cp38-abi3-win32.whl", hash = "sha256:0563655cb3c6d05fb2afe693340bc050c30f9f34e15763361cf08e94749401fc"},
|
{file = "cryptography-46.0.5-cp38-abi3-win32.whl", hash = "sha256:02f547fce831f5096c9a567fd41bc12ca8f11df260959ecc7c3202555cc47a72"},
|
||||||
{file = "cryptography-46.0.4-cp38-abi3-win_amd64.whl", hash = "sha256:fa0900b9ef9c49728887d1576fd8d9e7e3ea872fa9b25ef9b64888adc434e976"},
|
{file = "cryptography-46.0.5-cp38-abi3-win_amd64.whl", hash = "sha256:556e106ee01aa13484ce9b0239bca667be5004efb0aabbed28d353df86445595"},
|
||||||
{file = "cryptography-46.0.4-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:766330cce7416c92b5e90c3bb71b1b79521760cdcfc3a6a1a182d4c9fab23d2b"},
|
{file = "cryptography-46.0.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:3b4995dc971c9fb83c25aa44cf45f02ba86f71ee600d81091c2f0cbae116b06c"},
|
||||||
{file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c236a44acfb610e70f6b3e1c3ca20ff24459659231ef2f8c48e879e2d32b73da"},
|
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bc84e875994c3b445871ea7181d424588171efec3e185dced958dad9e001950a"},
|
||||||
{file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:8a15fb869670efa8f83cbffbc8753c1abf236883225aed74cd179b720ac9ec80"},
|
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2ae6971afd6246710480e3f15824ed3029a60fc16991db250034efd0b9fb4356"},
|
||||||
{file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:fdc3daab53b212472f1524d070735b2f0c214239df131903bae1d598016fa822"},
|
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:d861ee9e76ace6cf36a6a89b959ec08e7bc2493ee39d07ffe5acb23ef46d27da"},
|
||||||
{file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:44cc0675b27cadb71bdbb96099cca1fa051cd11d2ade09e5cd3a2edb929ed947"},
|
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:2b7a67c9cd56372f3249b39699f2ad479f6991e62ea15800973b956f4b73e257"},
|
||||||
{file = "cryptography-46.0.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:be8c01a7d5a55f9a47d1888162b76c8f49d62b234d88f0ff91a9fbebe32ffbc3"},
|
{file = "cryptography-46.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:8456928655f856c6e1533ff59d5be76578a7157224dbd9ce6872f25055ab9ab7"},
|
||||||
{file = "cryptography-46.0.4.tar.gz", hash = "sha256:bfd019f60f8abc2ed1b9be4ddc21cfef059c841d86d710bb69909a688cbb8f59"},
|
{file = "cryptography-46.0.5.tar.gz", hash = "sha256:abace499247268e3757271b2f1e244b36b06f8515cf27c4d49468fc9eb16e93d"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -516,7 +516,7 @@ nox = ["nox[uv] (>=2024.4.15)"]
|
|||||||
pep8test = ["check-sdist", "click (>=8.0.1)", "mypy (>=1.14)", "ruff (>=0.11.11)"]
|
pep8test = ["check-sdist", "click (>=8.0.1)", "mypy (>=1.14)", "ruff (>=0.11.11)"]
|
||||||
sdist = ["build (>=1.0.0)"]
|
sdist = ["build (>=1.0.0)"]
|
||||||
ssh = ["bcrypt (>=3.1.5)"]
|
ssh = ["bcrypt (>=3.1.5)"]
|
||||||
test = ["certifi (>=2024)", "cryptography-vectors (==46.0.4)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
|
test = ["certifi (>=2024)", "cryptography-vectors (==46.0.5)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
|
||||||
test-randomorder = ["pytest-randomly"]
|
test-randomorder = ["pytest-randomly"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -570,24 +570,25 @@ tests = ["coverage", "coveralls", "dill", "mock", "nose"]
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastapi"
|
name = "fastapi"
|
||||||
version = "0.128.0"
|
version = "0.128.7"
|
||||||
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
|
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d"},
|
{file = "fastapi-0.128.7-py3-none-any.whl", hash = "sha256:6bd9bd31cb7047465f2d3fa3ba3f33b0870b17d4eaf7cdb36d1576ab060ad662"},
|
||||||
{file = "fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a"},
|
{file = "fastapi-0.128.7.tar.gz", hash = "sha256:783c273416995486c155ad2c0e2b45905dedfaf20b9ef8d9f6a9124670639a24"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
annotated-doc = ">=0.0.2"
|
annotated-doc = ">=0.0.2"
|
||||||
pydantic = ">=2.7.0"
|
pydantic = ">=2.7.0"
|
||||||
starlette = ">=0.40.0,<0.51.0"
|
starlette = ">=0.40.0,<1.0.0"
|
||||||
typing-extensions = ">=4.8.0"
|
typing-extensions = ">=4.8.0"
|
||||||
|
typing-inspection = ">=0.4.2"
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=3.1.5)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
|
all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=3.1.5)", "orjson (>=3.9.3)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "pyyaml (>=5.3.1)", "ujson (>=5.8.0)", "uvicorn[standard] (>=0.12.0)"]
|
||||||
standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "jinja2 (>=3.1.5)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"]
|
standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "jinja2 (>=3.1.5)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"]
|
||||||
standard-no-fastapi-cloud-cli = ["email-validator (>=2.0.0)", "fastapi-cli[standard-no-fastapi-cloud-cli] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "jinja2 (>=3.1.5)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"]
|
standard-no-fastapi-cloud-cli = ["email-validator (>=2.0.0)", "fastapi-cli[standard-no-fastapi-cloud-cli] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "jinja2 (>=3.1.5)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"]
|
||||||
|
|
||||||
@@ -1062,14 +1063,14 @@ urllib3 = ">=1.26.0,<3"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "launchdarkly-server-sdk"
|
name = "launchdarkly-server-sdk"
|
||||||
version = "9.14.1"
|
version = "9.15.0"
|
||||||
description = "LaunchDarkly SDK for Python"
|
description = "LaunchDarkly SDK for Python"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.10"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "launchdarkly_server_sdk-9.14.1-py3-none-any.whl", hash = "sha256:a9e2bd9ecdef845cd631ae0d4334a1115e5b44257c42eb2349492be4bac7815c"},
|
{file = "launchdarkly_server_sdk-9.15.0-py3-none-any.whl", hash = "sha256:c267e29bfa3fb5e2a06a208448ada6ed5557a2924979b8d79c970b45d227c668"},
|
||||||
{file = "launchdarkly_server_sdk-9.14.1.tar.gz", hash = "sha256:1df44baf0a0efa74d8c1dad7a00592b98bce7d19edded7f770da8dbc49922213"},
|
{file = "launchdarkly_server_sdk-9.15.0.tar.gz", hash = "sha256:f31441b74bc1a69c381db57c33116509e407a2612628ad6dff0a7dbb39d5020b"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -1478,14 +1479,14 @@ testing = ["coverage", "pytest", "pytest-benchmark"]
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "postgrest"
|
name = "postgrest"
|
||||||
version = "2.27.2"
|
version = "2.28.0"
|
||||||
description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
|
description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "postgrest-2.27.2-py3-none-any.whl", hash = "sha256:1666fef3de05ca097a314433dd5ae2f2d71c613cb7b233d0f468c4ffe37277da"},
|
{file = "postgrest-2.28.0-py3-none-any.whl", hash = "sha256:7bca2f24dd1a1bf8a3d586c7482aba6cd41662da6733045fad585b63b7f7df75"},
|
||||||
{file = "postgrest-2.27.2.tar.gz", hash = "sha256:55407d530b5af3d64e883a71fec1f345d369958f723ce4a8ab0b7d169e313242"},
|
{file = "postgrest-2.28.0.tar.gz", hash = "sha256:c36b38646d25ea4255321d3d924ce70f8d20ec7799cb42c1221d6a818d4f6515"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -2248,14 +2249,14 @@ cli = ["click (>=5.0)"]
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "realtime"
|
name = "realtime"
|
||||||
version = "2.27.2"
|
version = "2.28.0"
|
||||||
description = ""
|
description = ""
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "realtime-2.27.2-py3-none-any.whl", hash = "sha256:34a9cbb26a274e707e8fc9e3ee0a66de944beac0fe604dc336d1e985db2c830f"},
|
{file = "realtime-2.28.0-py3-none-any.whl", hash = "sha256:db1bd59bab9b1fcc9f9d3b1a073bed35bf4994d720e6751f10031a58d57a3836"},
|
||||||
{file = "realtime-2.27.2.tar.gz", hash = "sha256:b960a90294d2cea1b3f1275ecb89204304728e08fff1c393cc1b3150739556b3"},
|
{file = "realtime-2.28.0.tar.gz", hash = "sha256:d18cedcebd6a8f22fcd509bc767f639761eb218b7b2b6f14fc4205b6259b50fc"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -2436,14 +2437,14 @@ full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "storage3"
|
name = "storage3"
|
||||||
version = "2.27.2"
|
version = "2.28.0"
|
||||||
description = "Supabase Storage client for Python."
|
description = "Supabase Storage client for Python."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "storage3-2.27.2-py3-none-any.whl", hash = "sha256:e6f16e7a260729e7b1f46e9bf61746805a02e30f5e419ee1291007c432e3ec63"},
|
{file = "storage3-2.28.0-py3-none-any.whl", hash = "sha256:ecb50efd2ac71dabbdf97e99ad346eafa630c4c627a8e5a138ceb5fbbadae716"},
|
||||||
{file = "storage3-2.27.2.tar.gz", hash = "sha256:cb4807b7f86b4bb1272ac6fdd2f3cfd8ba577297046fa5f88557425200275af5"},
|
{file = "storage3-2.28.0.tar.gz", hash = "sha256:bc1d008aff67de7a0f2bd867baee7aadbcdb6f78f5a310b4f7a38e8c13c19865"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -2487,35 +2488,35 @@ python-dateutil = ">=2.6.0"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "supabase"
|
name = "supabase"
|
||||||
version = "2.27.2"
|
version = "2.28.0"
|
||||||
description = "Supabase client for Python."
|
description = "Supabase client for Python."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "supabase-2.27.2-py3-none-any.whl", hash = "sha256:d4dce00b3a418ee578017ec577c0e5be47a9a636355009c76f20ed2faa15bc54"},
|
{file = "supabase-2.28.0-py3-none-any.whl", hash = "sha256:42776971c7d0ccca16034df1ab96a31c50228eb1eb19da4249ad2f756fc20272"},
|
||||||
{file = "supabase-2.27.2.tar.gz", hash = "sha256:2aed40e4f3454438822442a1e94a47be6694c2c70392e7ae99b51a226d4293f7"},
|
{file = "supabase-2.28.0.tar.gz", hash = "sha256:aea299aaab2a2eed3c57e0be7fc035c6807214194cce795a3575add20268ece1"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
httpx = ">=0.26,<0.29"
|
httpx = ">=0.26,<0.29"
|
||||||
postgrest = "2.27.2"
|
postgrest = "2.28.0"
|
||||||
realtime = "2.27.2"
|
realtime = "2.28.0"
|
||||||
storage3 = "2.27.2"
|
storage3 = "2.28.0"
|
||||||
supabase-auth = "2.27.2"
|
supabase-auth = "2.28.0"
|
||||||
supabase-functions = "2.27.2"
|
supabase-functions = "2.28.0"
|
||||||
yarl = ">=1.22.0"
|
yarl = ">=1.22.0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "supabase-auth"
|
name = "supabase-auth"
|
||||||
version = "2.27.2"
|
version = "2.28.0"
|
||||||
description = "Python Client Library for Supabase Auth"
|
description = "Python Client Library for Supabase Auth"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "supabase_auth-2.27.2-py3-none-any.whl", hash = "sha256:78ec25b11314d0a9527a7205f3b1c72560dccdc11b38392f80297ef98664ee91"},
|
{file = "supabase_auth-2.28.0-py3-none-any.whl", hash = "sha256:2ac85026cc285054c7fa6d41924f3a333e9ec298c013e5b5e1754039ba7caec9"},
|
||||||
{file = "supabase_auth-2.27.2.tar.gz", hash = "sha256:0f5bcc79b3677cb42e9d321f3c559070cfa40d6a29a67672cc8382fb7dc2fe97"},
|
{file = "supabase_auth-2.28.0.tar.gz", hash = "sha256:2bb8f18ff39934e44b28f10918db965659f3735cd6fbfcc022fe0b82dbf8233e"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -2525,14 +2526,14 @@ pyjwt = {version = ">=2.10.1", extras = ["crypto"]}
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "supabase-functions"
|
name = "supabase-functions"
|
||||||
version = "2.27.2"
|
version = "2.28.0"
|
||||||
description = "Library for Supabase Functions"
|
description = "Library for Supabase Functions"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "supabase_functions-2.27.2-py3-none-any.whl", hash = "sha256:db480efc669d0bca07605b9b6f167312af43121adcc842a111f79bea416ef754"},
|
{file = "supabase_functions-2.28.0-py3-none-any.whl", hash = "sha256:30bf2d586f8df285faf0621bb5d5bb3ec3157234fc820553ca156f009475e4ae"},
|
||||||
{file = "supabase_functions-2.27.2.tar.gz", hash = "sha256:d0c8266207a94371cb3fd35ad3c7f025b78a97cf026861e04ccd35ac1775f80b"},
|
{file = "supabase_functions-2.28.0.tar.gz", hash = "sha256:db3dddfc37aca5858819eb461130968473bd8c75bd284581013958526dac718b"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -2911,4 +2912,4 @@ type = ["pytest-mypy"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.1"
|
lock-version = "2.1"
|
||||||
python-versions = ">=3.10,<4.0"
|
python-versions = ">=3.10,<4.0"
|
||||||
content-hash = "40eae94995dc0a388fa832ed4af9b6137f28d5b5ced3aaea70d5f91d4d9a179d"
|
content-hash = "9619cae908ad38fa2c48016a58bcf4241f6f5793aa0e6cc140276e91c433cbbb"
|
||||||
|
|||||||
@@ -11,14 +11,14 @@ python = ">=3.10,<4.0"
|
|||||||
colorama = "^0.4.6"
|
colorama = "^0.4.6"
|
||||||
cryptography = "^46.0"
|
cryptography = "^46.0"
|
||||||
expiringdict = "^1.2.2"
|
expiringdict = "^1.2.2"
|
||||||
fastapi = "^0.128.0"
|
fastapi = "^0.128.7"
|
||||||
google-cloud-logging = "^3.13.0"
|
google-cloud-logging = "^3.13.0"
|
||||||
launchdarkly-server-sdk = "^9.14.1"
|
launchdarkly-server-sdk = "^9.15.0"
|
||||||
pydantic = "^2.12.5"
|
pydantic = "^2.12.5"
|
||||||
pydantic-settings = "^2.12.0"
|
pydantic-settings = "^2.12.0"
|
||||||
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
|
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
|
||||||
redis = "^6.2.0"
|
redis = "^6.2.0"
|
||||||
supabase = "^2.27.2"
|
supabase = "^2.28.0"
|
||||||
uvicorn = "^0.40.0"
|
uvicorn = "^0.40.0"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
|
|||||||
@@ -37,6 +37,10 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
|
|||||||
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
|
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
|
||||||
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
|
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
|
||||||
|
|
||||||
|
## ===== SIGNUP / INVITE GATE ===== ##
|
||||||
|
# Set to true to require an invite before users can sign up
|
||||||
|
ENABLE_INVITE_GATE=false
|
||||||
|
|
||||||
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
|
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
|
||||||
# Platform URLs (set these for webhooks and OAuth to work)
|
# Platform URLs (set these for webhooks and OAuth to work)
|
||||||
PLATFORM_BASE_URL=http://localhost:8000
|
PLATFORM_BASE_URL=http://localhost:8000
|
||||||
@@ -104,6 +108,12 @@ TWITTER_CLIENT_SECRET=
|
|||||||
# Make a new workspace for your OAuth APP -- trust me
|
# Make a new workspace for your OAuth APP -- trust me
|
||||||
# https://linear.app/settings/api/applications/new
|
# https://linear.app/settings/api/applications/new
|
||||||
# Callback URL: http://localhost:3000/auth/integrations/oauth_callback
|
# Callback URL: http://localhost:3000/auth/integrations/oauth_callback
|
||||||
|
LINEAR_API_KEY=
|
||||||
|
# Linear project and team IDs for the feature request tracker.
|
||||||
|
# Find these in your Linear workspace URL: linear.app/<workspace>/project/<project-id>
|
||||||
|
# and in team settings. Used by the chat copilot to file and search feature requests.
|
||||||
|
LINEAR_FEATURE_REQUEST_PROJECT_ID=
|
||||||
|
LINEAR_FEATURE_REQUEST_TEAM_ID=
|
||||||
LINEAR_CLIENT_ID=
|
LINEAR_CLIENT_ID=
|
||||||
LINEAR_CLIENT_SECRET=
|
LINEAR_CLIENT_SECRET=
|
||||||
|
|
||||||
@@ -184,5 +194,8 @@ ZEROBOUNCE_API_KEY=
|
|||||||
POSTHOG_API_KEY=
|
POSTHOG_API_KEY=
|
||||||
POSTHOG_HOST=https://eu.i.posthog.com
|
POSTHOG_HOST=https://eu.i.posthog.com
|
||||||
|
|
||||||
|
# Tally Form Integration (pre-populate business understanding on signup)
|
||||||
|
TALLY_API_KEY=
|
||||||
|
|
||||||
# Other Services
|
# Other Services
|
||||||
AUTOMOD_API_KEY=
|
AUTOMOD_API_KEY=
|
||||||
|
|||||||
@@ -58,10 +58,31 @@ poetry run pytest path/to/test.py --snapshot-update
|
|||||||
- **Authentication**: JWT-based with Supabase integration
|
- **Authentication**: JWT-based with Supabase integration
|
||||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||||
|
|
||||||
|
## Code Style
|
||||||
|
|
||||||
|
- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
|
||||||
|
- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
|
||||||
|
- **Pydantic models** over dataclass/namedtuple/dict for structured data
|
||||||
|
- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
|
||||||
|
- **List comprehensions** over manual loop-and-append
|
||||||
|
- **Early return** — guard clauses first, avoid deep nesting
|
||||||
|
- **Lazy `%s` logging** — `logger.info("Processing %s items", count)` not `logger.info(f"Processing {count} items")`
|
||||||
|
- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
|
||||||
|
- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
|
||||||
|
- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
|
||||||
|
- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
|
||||||
|
- **`max(0, value)` guards** — for computed values that should never be negative
|
||||||
|
- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
|
||||||
|
- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
|
||||||
|
- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
|
||||||
|
|
||||||
## Testing Approach
|
## Testing Approach
|
||||||
|
|
||||||
- Uses pytest with snapshot testing for API responses
|
- Uses pytest with snapshot testing for API responses
|
||||||
- Test files are colocated with source files (`*_test.py`)
|
- Test files are colocated with source files (`*_test.py`)
|
||||||
|
- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
|
||||||
|
- After refactoring, update mock targets to match new module paths
|
||||||
|
- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
|
||||||
|
|
||||||
## Database Schema
|
## Database Schema
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
# ============================ DEPENDENCY BUILDER ============================ #
|
||||||
|
|
||||||
FROM debian:13-slim AS builder
|
FROM debian:13-slim AS builder
|
||||||
|
|
||||||
# Set environment variables
|
# Set environment variables
|
||||||
@@ -51,60 +53,106 @@ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/parti
|
|||||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||||
RUN poetry run prisma generate && poetry run gen-prisma-stub
|
RUN poetry run prisma generate && poetry run gen-prisma-stub
|
||||||
|
|
||||||
FROM debian:13-slim AS server_dependencies
|
# =============================== DB MIGRATOR =============================== #
|
||||||
|
|
||||||
|
# Lightweight migrate stage - only needs Prisma CLI, not full Python environment
|
||||||
|
FROM debian:13-slim AS migrate
|
||||||
|
|
||||||
|
WORKDIR /app/autogpt_platform/backend
|
||||||
|
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
# Install only what's needed for prisma migrate: Node.js and minimal Python for prisma-python
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
python3.13 \
|
||||||
|
python3-pip \
|
||||||
|
ca-certificates \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Copy Node.js from builder (needed for Prisma CLI)
|
||||||
|
COPY --from=builder /usr/bin/node /usr/bin/node
|
||||||
|
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
||||||
|
COPY --from=builder /usr/bin/npm /usr/bin/npm
|
||||||
|
|
||||||
|
# Copy Prisma binaries
|
||||||
|
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||||
|
|
||||||
|
# Install prisma-client-py directly (much smaller than copying full venv)
|
||||||
|
RUN pip3 install prisma>=0.15.0 --break-system-packages
|
||||||
|
|
||||||
|
COPY autogpt_platform/backend/schema.prisma ./
|
||||||
|
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||||
|
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||||
|
COPY autogpt_platform/backend/migrations ./migrations
|
||||||
|
|
||||||
|
# ============================== BACKEND SERVER ============================== #
|
||||||
|
|
||||||
|
FROM debian:13-slim AS server
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
ENV POETRY_HOME=/opt/poetry \
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
POETRY_NO_INTERACTION=1 \
|
|
||||||
POETRY_VIRTUALENVS_CREATE=true \
|
|
||||||
POETRY_VIRTUALENVS_IN_PROJECT=true \
|
|
||||||
DEBIAN_FRONTEND=noninteractive
|
|
||||||
ENV PATH=/opt/poetry/bin:$PATH
|
|
||||||
|
|
||||||
# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
|
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
|
||||||
RUN apt-get update && apt-get install -y \
|
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
|
||||||
|
# for the bash_exec MCP tool (fallback when E2B is not configured).
|
||||||
|
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
python3.13 \
|
python3.13 \
|
||||||
python3-pip \
|
python3-pip \
|
||||||
ffmpeg \
|
ffmpeg \
|
||||||
imagemagick \
|
imagemagick \
|
||||||
|
jq \
|
||||||
|
ripgrep \
|
||||||
|
tree \
|
||||||
|
bubblewrap \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Copy only necessary files from builder
|
# Copy poetry (build-time only, for `poetry install --only-root` to create entry points)
|
||||||
COPY --from=builder /app /app
|
|
||||||
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
||||||
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
|
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
|
||||||
# Copy Node.js installation for Prisma
|
# Copy Node.js installation for Prisma and agent-browser.
|
||||||
|
# npm/npx are symlinks in the builder (-> ../lib/node_modules/npm/bin/*-cli.js);
|
||||||
|
# COPY resolves them to regular files, breaking require() paths. Recreate as
|
||||||
|
# proper symlinks so npm/npx can find their modules.
|
||||||
COPY --from=builder /usr/bin/node /usr/bin/node
|
COPY --from=builder /usr/bin/node /usr/bin/node
|
||||||
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
||||||
COPY --from=builder /usr/bin/npm /usr/bin/npm
|
RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
|
||||||
COPY --from=builder /usr/bin/npx /usr/bin/npx
|
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
|
||||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||||
|
|
||||||
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
|
# Install agent-browser (Copilot browser tool) + Chromium runtime dependencies.
|
||||||
|
# These are the runtime libraries Chromium/Playwright needs on Debian 13 (trixie).
|
||||||
RUN mkdir -p /app/autogpt_platform/autogpt_libs
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
RUN mkdir -p /app/autogpt_platform/backend
|
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
|
||||||
|
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
|
||||||
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
|
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
|
||||||
|
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
|
||||||
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml /app/autogpt_platform/backend/
|
fonts-liberation libfontconfig1 \
|
||||||
|
&& rm -rf /var/lib/apt/lists/* \
|
||||||
|
&& npm install -g agent-browser \
|
||||||
|
&& agent-browser install \
|
||||||
|
&& rm -rf /tmp/* /root/.npm
|
||||||
|
|
||||||
WORKDIR /app/autogpt_platform/backend
|
WORKDIR /app/autogpt_platform/backend
|
||||||
|
|
||||||
FROM server_dependencies AS migrate
|
# Copy only the .venv from builder (not the entire /app directory)
|
||||||
|
# The .venv includes the generated Prisma client
|
||||||
|
COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
|
||||||
|
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
|
||||||
|
|
||||||
# Migration stage only needs schema and migrations - much lighter than full backend
|
# Copy dependency files + autogpt_libs (path dependency)
|
||||||
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
|
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
|
||||||
COPY autogpt_platform/backend/backend/data/partial_types.py /app/autogpt_platform/backend/backend/data/partial_types.py
|
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./
|
||||||
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations
|
|
||||||
|
|
||||||
FROM server_dependencies AS server
|
# Copy backend code + docs (for Copilot docs search)
|
||||||
|
COPY autogpt_platform/backend ./
|
||||||
COPY autogpt_platform/backend /app/autogpt_platform/backend
|
|
||||||
COPY docs /app/docs
|
COPY docs /app/docs
|
||||||
RUN poetry install --no-ansi --only-root
|
# Install the project package to create entry point scripts in .venv/bin/
|
||||||
|
# (e.g., rest, executor, ws, db, scheduler, notification - see [tool.poetry.scripts])
|
||||||
|
RUN POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true \
|
||||||
|
poetry install --no-ansi --only-root
|
||||||
|
|
||||||
ENV PORT=8000
|
ENV PORT=8000
|
||||||
|
|
||||||
CMD ["poetry", "run", "rest"]
|
CMD ["rest"]
|
||||||
|
|||||||
@@ -1,4 +1,9 @@
|
|||||||
"""Common test fixtures for server tests."""
|
"""Common test fixtures for server tests.
|
||||||
|
|
||||||
|
Note: Common fixtures like test_user_id, admin_user_id, target_user_id,
|
||||||
|
setup_test_user, and setup_admin_user are defined in the parent conftest.py
|
||||||
|
(backend/conftest.py) and are available here automatically.
|
||||||
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pytest_snapshot.plugin import Snapshot
|
from pytest_snapshot.plugin import Snapshot
|
||||||
@@ -11,54 +16,6 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
|
|||||||
return snapshot
|
return snapshot
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def test_user_id() -> str:
|
|
||||||
"""Test user ID fixture."""
|
|
||||||
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def admin_user_id() -> str:
|
|
||||||
"""Admin user ID fixture."""
|
|
||||||
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def target_user_id() -> str:
|
|
||||||
"""Target user ID fixture."""
|
|
||||||
return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def setup_test_user(test_user_id):
|
|
||||||
"""Create test user in database before tests."""
|
|
||||||
from backend.data.user import get_or_create_user
|
|
||||||
|
|
||||||
# Create the test user in the database using JWT token format
|
|
||||||
user_data = {
|
|
||||||
"sub": test_user_id,
|
|
||||||
"email": "test@example.com",
|
|
||||||
"user_metadata": {"name": "Test User"},
|
|
||||||
}
|
|
||||||
await get_or_create_user(user_data)
|
|
||||||
return test_user_id
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def setup_admin_user(admin_user_id):
|
|
||||||
"""Create admin user in database before tests."""
|
|
||||||
from backend.data.user import get_or_create_user
|
|
||||||
|
|
||||||
# Create the admin user in the database using JWT token format
|
|
||||||
user_data = {
|
|
||||||
"sub": admin_user_id,
|
|
||||||
"email": "test-admin@example.com",
|
|
||||||
"user_metadata": {"name": "Test Admin"},
|
|
||||||
}
|
|
||||||
await get_or_create_user(user_data)
|
|
||||||
return admin_user_id
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_jwt_user(test_user_id):
|
def mock_jwt_user(test_user_id):
|
||||||
"""Provide mock JWT payload for regular user testing."""
|
"""Provide mock JWT payload for regular user testing."""
|
||||||
|
|||||||
@@ -88,20 +88,23 @@ async def require_auth(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def require_permission(permission: APIKeyPermission):
|
def require_permission(*permissions: APIKeyPermission):
|
||||||
"""
|
"""
|
||||||
Dependency function for checking specific permissions
|
Dependency function for checking required permissions.
|
||||||
|
All listed permissions must be present.
|
||||||
(works with API keys and OAuth tokens)
|
(works with API keys and OAuth tokens)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def check_permission(
|
async def check_permissions(
|
||||||
auth: APIAuthorizationInfo = Security(require_auth),
|
auth: APIAuthorizationInfo = Security(require_auth),
|
||||||
) -> APIAuthorizationInfo:
|
) -> APIAuthorizationInfo:
|
||||||
if permission not in auth.scopes:
|
missing = [p for p in permissions if p not in auth.scopes]
|
||||||
|
if missing:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail=f"Missing required permission: {permission.value}",
|
detail=f"Missing required permission(s): "
|
||||||
|
f"{', '.join(p.value for p in missing)}",
|
||||||
)
|
)
|
||||||
return auth
|
return auth
|
||||||
|
|
||||||
return check_permission
|
return check_permissions
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Annotated, Any, Literal, Optional, Sequence
|
from typing import Annotated, Any, Optional, Sequence
|
||||||
|
|
||||||
from fastapi import APIRouter, Body, HTTPException, Security
|
from fastapi import APIRouter, Body, HTTPException, Security
|
||||||
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
from prisma.enums import AgentExecutionStatus, APIKeyPermission
|
||||||
@@ -9,15 +9,17 @@ from pydantic import BaseModel, Field
|
|||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
import backend.api.features.store.cache as store_cache
|
import backend.api.features.store.cache as store_cache
|
||||||
|
import backend.api.features.store.db as store_db
|
||||||
import backend.api.features.store.model as store_model
|
import backend.api.features.store.model as store_model
|
||||||
import backend.blocks
|
import backend.blocks
|
||||||
from backend.api.external.middleware import require_permission
|
from backend.api.external.middleware import require_auth, require_permission
|
||||||
from backend.data import execution as execution_db
|
from backend.data import execution as execution_db
|
||||||
from backend.data import graph as graph_db
|
from backend.data import graph as graph_db
|
||||||
from backend.data import user as user_db
|
from backend.data import user as user_db
|
||||||
from backend.data.auth.base import APIAuthorizationInfo
|
from backend.data.auth.base import APIAuthorizationInfo
|
||||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||||
from backend.executor.utils import add_graph_execution
|
from backend.executor.utils import add_graph_execution
|
||||||
|
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
||||||
from backend.util.settings import Settings
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
from .integrations import integrations_router
|
from .integrations import integrations_router
|
||||||
@@ -95,6 +97,43 @@ async def execute_graph_block(
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@v1_router.post(
|
||||||
|
path="/graphs",
|
||||||
|
tags=["graphs"],
|
||||||
|
status_code=201,
|
||||||
|
dependencies=[
|
||||||
|
Security(
|
||||||
|
require_permission(
|
||||||
|
APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY
|
||||||
|
)
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def create_graph(
|
||||||
|
graph: graph_db.Graph,
|
||||||
|
auth: APIAuthorizationInfo = Security(
|
||||||
|
require_permission(APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY)
|
||||||
|
),
|
||||||
|
) -> graph_db.GraphModel:
|
||||||
|
"""
|
||||||
|
Create a new agent graph.
|
||||||
|
|
||||||
|
The graph will be validated and assigned a new ID.
|
||||||
|
It is automatically added to the user's library.
|
||||||
|
"""
|
||||||
|
from backend.api.features.library import db as library_db
|
||||||
|
|
||||||
|
graph_model = graph_db.make_graph_model(graph, auth.user_id)
|
||||||
|
graph_model.reassign_ids(user_id=auth.user_id, reassign_graph_id=True)
|
||||||
|
graph_model.validate_graph(for_run=False)
|
||||||
|
|
||||||
|
await graph_db.create_graph(graph_model, user_id=auth.user_id)
|
||||||
|
await library_db.create_library_agent(graph_model, auth.user_id)
|
||||||
|
activated_graph = await on_graph_activate(graph_model, user_id=auth.user_id)
|
||||||
|
|
||||||
|
return activated_graph
|
||||||
|
|
||||||
|
|
||||||
@v1_router.post(
|
@v1_router.post(
|
||||||
path="/graphs/{graph_id}/execute/{graph_version}",
|
path="/graphs/{graph_id}/execute/{graph_version}",
|
||||||
tags=["graphs"],
|
tags=["graphs"],
|
||||||
@@ -192,13 +231,13 @@ async def get_graph_execution_results(
|
|||||||
@v1_router.get(
|
@v1_router.get(
|
||||||
path="/store/agents",
|
path="/store/agents",
|
||||||
tags=["store"],
|
tags=["store"],
|
||||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||||
response_model=store_model.StoreAgentsResponse,
|
response_model=store_model.StoreAgentsResponse,
|
||||||
)
|
)
|
||||||
async def get_store_agents(
|
async def get_store_agents(
|
||||||
featured: bool = False,
|
featured: bool = False,
|
||||||
creator: str | None = None,
|
creator: str | None = None,
|
||||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
|
sorted_by: store_db.StoreAgentsSortOptions | None = None,
|
||||||
search_query: str | None = None,
|
search_query: str | None = None,
|
||||||
category: str | None = None,
|
category: str | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
@@ -240,7 +279,7 @@ async def get_store_agents(
|
|||||||
@v1_router.get(
|
@v1_router.get(
|
||||||
path="/store/agents/{username}/{agent_name}",
|
path="/store/agents/{username}/{agent_name}",
|
||||||
tags=["store"],
|
tags=["store"],
|
||||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||||
response_model=store_model.StoreAgentDetails,
|
response_model=store_model.StoreAgentDetails,
|
||||||
)
|
)
|
||||||
async def get_store_agent(
|
async def get_store_agent(
|
||||||
@@ -268,13 +307,13 @@ async def get_store_agent(
|
|||||||
@v1_router.get(
|
@v1_router.get(
|
||||||
path="/store/creators",
|
path="/store/creators",
|
||||||
tags=["store"],
|
tags=["store"],
|
||||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||||
response_model=store_model.CreatorsResponse,
|
response_model=store_model.CreatorsResponse,
|
||||||
)
|
)
|
||||||
async def get_store_creators(
|
async def get_store_creators(
|
||||||
featured: bool = False,
|
featured: bool = False,
|
||||||
search_query: str | None = None,
|
search_query: str | None = None,
|
||||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
|
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 20,
|
page_size: int = 20,
|
||||||
) -> store_model.CreatorsResponse:
|
) -> store_model.CreatorsResponse:
|
||||||
@@ -310,7 +349,7 @@ async def get_store_creators(
|
|||||||
@v1_router.get(
|
@v1_router.get(
|
||||||
path="/store/creators/{username}",
|
path="/store/creators/{username}",
|
||||||
tags=["store"],
|
tags=["store"],
|
||||||
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
|
dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
|
||||||
response_model=store_model.CreatorDetails,
|
response_model=store_model.CreatorDetails,
|
||||||
)
|
)
|
||||||
async def get_store_creator(
|
async def get_store_creator(
|
||||||
|
|||||||
@@ -15,9 +15,9 @@ from prisma.enums import APIKeyPermission
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from backend.api.external.middleware import require_permission
|
from backend.api.external.middleware import require_permission
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.copilot.model import ChatSession
|
||||||
from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
|
from backend.copilot.tools import find_agent_tool, run_agent_tool
|
||||||
from backend.api.features.chat.tools.models import ToolResponseBase
|
from backend.copilot.tools.models import ToolResponseBase
|
||||||
from backend.data.auth.base import APIAuthorizationInfo
|
from backend.data.auth.base import APIAuthorizationInfo
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -1,8 +1,17 @@
|
|||||||
from pydantic import BaseModel
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||||
|
|
||||||
|
import prisma.enums
|
||||||
|
from pydantic import BaseModel, EmailStr
|
||||||
|
|
||||||
from backend.data.model import UserTransaction
|
from backend.data.model import UserTransaction
|
||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from backend.data.invited_user import BulkInvitedUsersResult, InvitedUserRecord
|
||||||
|
|
||||||
|
|
||||||
class UserHistoryResponse(BaseModel):
|
class UserHistoryResponse(BaseModel):
|
||||||
"""Response model for listings with version history"""
|
"""Response model for listings with version history"""
|
||||||
@@ -14,3 +23,70 @@ class UserHistoryResponse(BaseModel):
|
|||||||
class AddUserCreditsResponse(BaseModel):
|
class AddUserCreditsResponse(BaseModel):
|
||||||
new_balance: int
|
new_balance: int
|
||||||
transaction_key: str
|
transaction_key: str
|
||||||
|
|
||||||
|
|
||||||
|
class CreateInvitedUserRequest(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
name: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class InvitedUserResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
email: str
|
||||||
|
status: prisma.enums.InvitedUserStatus
|
||||||
|
auth_user_id: Optional[str] = None
|
||||||
|
name: Optional[str] = None
|
||||||
|
tally_understanding: Optional[dict[str, Any]] = None
|
||||||
|
tally_status: prisma.enums.TallyComputationStatus
|
||||||
|
tally_computed_at: Optional[datetime] = None
|
||||||
|
tally_error: Optional[str] = None
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_record(cls, record: InvitedUserRecord) -> InvitedUserResponse:
|
||||||
|
return cls.model_validate(record.model_dump())
|
||||||
|
|
||||||
|
|
||||||
|
class InvitedUsersResponse(BaseModel):
|
||||||
|
invited_users: list[InvitedUserResponse]
|
||||||
|
pagination: Pagination
|
||||||
|
|
||||||
|
|
||||||
|
class BulkInvitedUserRowResponse(BaseModel):
|
||||||
|
row_number: int
|
||||||
|
email: Optional[str] = None
|
||||||
|
name: Optional[str] = None
|
||||||
|
status: Literal["CREATED", "SKIPPED", "ERROR"]
|
||||||
|
message: str
|
||||||
|
invited_user: Optional[InvitedUserResponse] = None
|
||||||
|
|
||||||
|
|
||||||
|
class BulkInvitedUsersResponse(BaseModel):
|
||||||
|
created_count: int
|
||||||
|
skipped_count: int
|
||||||
|
error_count: int
|
||||||
|
results: list[BulkInvitedUserRowResponse]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_result(cls, result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
|
||||||
|
return cls(
|
||||||
|
created_count=result.created_count,
|
||||||
|
skipped_count=result.skipped_count,
|
||||||
|
error_count=result.error_count,
|
||||||
|
results=[
|
||||||
|
BulkInvitedUserRowResponse(
|
||||||
|
row_number=row.row_number,
|
||||||
|
email=row.email,
|
||||||
|
name=row.name,
|
||||||
|
status=row.status,
|
||||||
|
message=row.message,
|
||||||
|
invited_user=(
|
||||||
|
InvitedUserResponse.from_record(row.invited_user)
|
||||||
|
if row.invited_user is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for row in result.results
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|||||||
@@ -24,14 +24,13 @@ router = fastapi.APIRouter(
|
|||||||
@router.get(
|
@router.get(
|
||||||
"/listings",
|
"/listings",
|
||||||
summary="Get Admin Listings History",
|
summary="Get Admin Listings History",
|
||||||
response_model=store_model.StoreListingsWithVersionsResponse,
|
|
||||||
)
|
)
|
||||||
async def get_admin_listings_with_versions(
|
async def get_admin_listings_with_versions(
|
||||||
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
|
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
|
||||||
search: typing.Optional[str] = None,
|
search: typing.Optional[str] = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 20,
|
page_size: int = 20,
|
||||||
):
|
) -> store_model.StoreListingsWithVersionsAdminViewResponse:
|
||||||
"""
|
"""
|
||||||
Get store listings with their version history for admins.
|
Get store listings with their version history for admins.
|
||||||
|
|
||||||
@@ -45,36 +44,26 @@ async def get_admin_listings_with_versions(
|
|||||||
page_size: Number of items per page
|
page_size: Number of items per page
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
StoreListingsWithVersionsResponse with listings and their versions
|
Paginated listings with their versions
|
||||||
"""
|
"""
|
||||||
try:
|
listings = await store_db.get_admin_listings_with_versions(
|
||||||
listings = await store_db.get_admin_listings_with_versions(
|
status=status,
|
||||||
status=status,
|
search_query=search,
|
||||||
search_query=search,
|
page=page,
|
||||||
page=page,
|
page_size=page_size,
|
||||||
page_size=page_size,
|
)
|
||||||
)
|
return listings
|
||||||
return listings
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("Error getting admin listings with versions: %s", e)
|
|
||||||
return fastapi.responses.JSONResponse(
|
|
||||||
status_code=500,
|
|
||||||
content={
|
|
||||||
"detail": "An error occurred while retrieving listings with versions"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/submissions/{store_listing_version_id}/review",
|
"/submissions/{store_listing_version_id}/review",
|
||||||
summary="Review Store Submission",
|
summary="Review Store Submission",
|
||||||
response_model=store_model.StoreSubmission,
|
|
||||||
)
|
)
|
||||||
async def review_submission(
|
async def review_submission(
|
||||||
store_listing_version_id: str,
|
store_listing_version_id: str,
|
||||||
request: store_model.ReviewSubmissionRequest,
|
request: store_model.ReviewSubmissionRequest,
|
||||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||||
):
|
) -> store_model.StoreSubmissionAdminView:
|
||||||
"""
|
"""
|
||||||
Review a store listing submission.
|
Review a store listing submission.
|
||||||
|
|
||||||
@@ -84,31 +73,24 @@ async def review_submission(
|
|||||||
user_id: Authenticated admin user performing the review
|
user_id: Authenticated admin user performing the review
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
StoreSubmission with updated review information
|
StoreSubmissionAdminView with updated review information
|
||||||
"""
|
"""
|
||||||
try:
|
already_approved = await store_db.check_submission_already_approved(
|
||||||
already_approved = await store_db.check_submission_already_approved(
|
store_listing_version_id=store_listing_version_id,
|
||||||
store_listing_version_id=store_listing_version_id,
|
)
|
||||||
)
|
submission = await store_db.review_store_submission(
|
||||||
submission = await store_db.review_store_submission(
|
store_listing_version_id=store_listing_version_id,
|
||||||
store_listing_version_id=store_listing_version_id,
|
is_approved=request.is_approved,
|
||||||
is_approved=request.is_approved,
|
external_comments=request.comments,
|
||||||
external_comments=request.comments,
|
internal_comments=request.internal_comments or "",
|
||||||
internal_comments=request.internal_comments or "",
|
reviewer_id=user_id,
|
||||||
reviewer_id=user_id,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
state_changed = already_approved != request.is_approved
|
state_changed = already_approved != request.is_approved
|
||||||
# Clear caches when the request is approved as it updates what is shown on the store
|
# Clear caches whenever approval state changes, since store visibility can change
|
||||||
if state_changed:
|
if state_changed:
|
||||||
store_cache.clear_all_caches()
|
store_cache.clear_all_caches()
|
||||||
return submission
|
return submission
|
||||||
except Exception as e:
|
|
||||||
logger.exception("Error reviewing submission: %s", e)
|
|
||||||
return fastapi.responses.JSONResponse(
|
|
||||||
status_code=500,
|
|
||||||
content={"detail": "An error occurred while reviewing the submission"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
|
|||||||
@@ -0,0 +1,137 @@
|
|||||||
|
import logging
|
||||||
|
import math
|
||||||
|
|
||||||
|
from autogpt_libs.auth import get_user_id, requires_admin_user
|
||||||
|
from fastapi import APIRouter, File, Query, Security, UploadFile
|
||||||
|
|
||||||
|
from backend.data.invited_user import (
|
||||||
|
bulk_create_invited_users_from_file,
|
||||||
|
create_invited_user,
|
||||||
|
list_invited_users,
|
||||||
|
retry_invited_user_tally,
|
||||||
|
revoke_invited_user,
|
||||||
|
)
|
||||||
|
from backend.data.tally import mask_email
|
||||||
|
from backend.util.models import Pagination
|
||||||
|
|
||||||
|
from .model import (
|
||||||
|
BulkInvitedUsersResponse,
|
||||||
|
CreateInvitedUserRequest,
|
||||||
|
InvitedUserResponse,
|
||||||
|
InvitedUsersResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/admin",
|
||||||
|
tags=["users", "admin"],
|
||||||
|
dependencies=[Security(requires_admin_user)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/invited-users",
|
||||||
|
response_model=InvitedUsersResponse,
|
||||||
|
summary="List Invited Users",
|
||||||
|
)
|
||||||
|
async def get_invited_users(
|
||||||
|
admin_user_id: str = Security(get_user_id),
|
||||||
|
page: int = Query(1, ge=1),
|
||||||
|
page_size: int = Query(50, ge=1, le=200),
|
||||||
|
) -> InvitedUsersResponse:
|
||||||
|
logger.info("Admin user %s requested invited users", admin_user_id)
|
||||||
|
invited_users, total = await list_invited_users(page=page, page_size=page_size)
|
||||||
|
return InvitedUsersResponse(
|
||||||
|
invited_users=[InvitedUserResponse.from_record(iu) for iu in invited_users],
|
||||||
|
pagination=Pagination(
|
||||||
|
total_items=total,
|
||||||
|
total_pages=max(1, math.ceil(total / page_size)),
|
||||||
|
current_page=page,
|
||||||
|
page_size=page_size,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/invited-users",
|
||||||
|
response_model=InvitedUserResponse,
|
||||||
|
summary="Create Invited User",
|
||||||
|
)
|
||||||
|
async def create_invited_user_route(
|
||||||
|
request: CreateInvitedUserRequest,
|
||||||
|
admin_user_id: str = Security(get_user_id),
|
||||||
|
) -> InvitedUserResponse:
|
||||||
|
logger.info(
|
||||||
|
"Admin user %s creating invited user for %s",
|
||||||
|
admin_user_id,
|
||||||
|
mask_email(request.email),
|
||||||
|
)
|
||||||
|
invited_user = await create_invited_user(request.email, request.name)
|
||||||
|
logger.info(
|
||||||
|
"Admin user %s created invited user %s",
|
||||||
|
admin_user_id,
|
||||||
|
invited_user.id,
|
||||||
|
)
|
||||||
|
return InvitedUserResponse.from_record(invited_user)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/invited-users/bulk",
|
||||||
|
response_model=BulkInvitedUsersResponse,
|
||||||
|
summary="Bulk Create Invited Users",
|
||||||
|
operation_id="postV2BulkCreateInvitedUsers",
|
||||||
|
)
|
||||||
|
async def bulk_create_invited_users_route(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
admin_user_id: str = Security(get_user_id),
|
||||||
|
) -> BulkInvitedUsersResponse:
|
||||||
|
logger.info(
|
||||||
|
"Admin user %s bulk invited users from %s",
|
||||||
|
admin_user_id,
|
||||||
|
file.filename or "<unnamed>",
|
||||||
|
)
|
||||||
|
content = await file.read()
|
||||||
|
result = await bulk_create_invited_users_from_file(file.filename, content)
|
||||||
|
return BulkInvitedUsersResponse.from_result(result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/invited-users/{invited_user_id}/revoke",
|
||||||
|
response_model=InvitedUserResponse,
|
||||||
|
summary="Revoke Invited User",
|
||||||
|
)
|
||||||
|
async def revoke_invited_user_route(
|
||||||
|
invited_user_id: str,
|
||||||
|
admin_user_id: str = Security(get_user_id),
|
||||||
|
) -> InvitedUserResponse:
|
||||||
|
logger.info(
|
||||||
|
"Admin user %s revoking invited user %s", admin_user_id, invited_user_id
|
||||||
|
)
|
||||||
|
invited_user = await revoke_invited_user(invited_user_id)
|
||||||
|
logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
|
||||||
|
return InvitedUserResponse.from_record(invited_user)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/invited-users/{invited_user_id}/retry-tally",
|
||||||
|
response_model=InvitedUserResponse,
|
||||||
|
summary="Retry Invited User Tally",
|
||||||
|
)
|
||||||
|
async def retry_invited_user_tally_route(
|
||||||
|
invited_user_id: str,
|
||||||
|
admin_user_id: str = Security(get_user_id),
|
||||||
|
) -> InvitedUserResponse:
|
||||||
|
logger.info(
|
||||||
|
"Admin user %s retrying Tally seed for invited user %s",
|
||||||
|
admin_user_id,
|
||||||
|
invited_user_id,
|
||||||
|
)
|
||||||
|
invited_user = await retry_invited_user_tally(invited_user_id)
|
||||||
|
logger.info(
|
||||||
|
"Admin user %s retried Tally seed for invited user %s",
|
||||||
|
admin_user_id,
|
||||||
|
invited_user_id,
|
||||||
|
)
|
||||||
|
return InvitedUserResponse.from_record(invited_user)
|
||||||
@@ -0,0 +1,168 @@
|
|||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import fastapi
|
||||||
|
import fastapi.testclient
|
||||||
|
import prisma.enums
|
||||||
|
import pytest
|
||||||
|
import pytest_mock
|
||||||
|
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||||
|
|
||||||
|
from backend.data.invited_user import (
|
||||||
|
BulkInvitedUserRowResult,
|
||||||
|
BulkInvitedUsersResult,
|
||||||
|
InvitedUserRecord,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .user_admin_routes import router as user_admin_router
|
||||||
|
|
||||||
|
app = fastapi.FastAPI()
|
||||||
|
app.include_router(user_admin_router)
|
||||||
|
|
||||||
|
client = fastapi.testclient.TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_app_admin_auth(mock_jwt_admin):
|
||||||
|
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_invited_user() -> InvitedUserRecord:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
return InvitedUserRecord(
|
||||||
|
id="invite-1",
|
||||||
|
email="invited@example.com",
|
||||||
|
status=prisma.enums.InvitedUserStatus.INVITED,
|
||||||
|
auth_user_id=None,
|
||||||
|
name="Invited User",
|
||||||
|
tally_understanding=None,
|
||||||
|
tally_status=prisma.enums.TallyComputationStatus.PENDING,
|
||||||
|
tally_computed_at=None,
|
||||||
|
tally_error=None,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
|
||||||
|
return BulkInvitedUsersResult(
|
||||||
|
created_count=1,
|
||||||
|
skipped_count=1,
|
||||||
|
error_count=0,
|
||||||
|
results=[
|
||||||
|
BulkInvitedUserRowResult(
|
||||||
|
row_number=1,
|
||||||
|
email="invited@example.com",
|
||||||
|
name=None,
|
||||||
|
status="CREATED",
|
||||||
|
message="Invite created",
|
||||||
|
invited_user=_sample_invited_user(),
|
||||||
|
),
|
||||||
|
BulkInvitedUserRowResult(
|
||||||
|
row_number=2,
|
||||||
|
email="duplicate@example.com",
|
||||||
|
name=None,
|
||||||
|
status="SKIPPED",
|
||||||
|
message="An invited user with this email already exists",
|
||||||
|
invited_user=None,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_invited_users(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.user_admin_routes.list_invited_users",
|
||||||
|
AsyncMock(return_value=([_sample_invited_user()], 1)),
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.get("/admin/invited-users")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["invited_users"]) == 1
|
||||||
|
assert data["invited_users"][0]["email"] == "invited@example.com"
|
||||||
|
assert data["invited_users"][0]["status"] == "INVITED"
|
||||||
|
assert data["pagination"]["total_items"] == 1
|
||||||
|
assert data["pagination"]["current_page"] == 1
|
||||||
|
assert data["pagination"]["page_size"] == 50
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_invited_user(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.user_admin_routes.create_invited_user",
|
||||||
|
AsyncMock(return_value=_sample_invited_user()),
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/admin/invited-users",
|
||||||
|
json={"email": "invited@example.com", "name": "Invited User"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["email"] == "invited@example.com"
|
||||||
|
assert data["name"] == "Invited User"
|
||||||
|
|
||||||
|
|
||||||
|
def test_bulk_create_invited_users(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
|
||||||
|
AsyncMock(return_value=_sample_bulk_invited_users_result()),
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/admin/invited-users/bulk",
|
||||||
|
files={
|
||||||
|
"file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["created_count"] == 1
|
||||||
|
assert data["skipped_count"] == 1
|
||||||
|
assert data["results"][0]["status"] == "CREATED"
|
||||||
|
assert data["results"][1]["status"] == "SKIPPED"
|
||||||
|
|
||||||
|
|
||||||
|
def test_revoke_invited_user(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
revoked = _sample_invited_user().model_copy(
|
||||||
|
update={"status": prisma.enums.InvitedUserStatus.REVOKED}
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.user_admin_routes.revoke_invited_user",
|
||||||
|
AsyncMock(return_value=revoked),
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post("/admin/invited-users/invite-1/revoke")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["status"] == "REVOKED"
|
||||||
|
|
||||||
|
|
||||||
|
def test_retry_invited_user_tally(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
retried = _sample_invited_user().model_copy(
|
||||||
|
update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
|
||||||
|
AsyncMock(return_value=retried),
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post("/admin/invited-users/invite-1/retry-tally")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["tally_status"] == "RUNNING"
|
||||||
@@ -1,15 +1,17 @@
|
|||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timedelta, timezone
|
|
||||||
from difflib import SequenceMatcher
|
from difflib import SequenceMatcher
|
||||||
from typing import Sequence
|
from typing import Any, Sequence, get_args, get_origin
|
||||||
|
|
||||||
import prisma
|
import prisma
|
||||||
|
from prisma.enums import ContentType
|
||||||
|
from prisma.models import mv_suggested_blocks
|
||||||
|
|
||||||
import backend.api.features.library.db as library_db
|
import backend.api.features.library.db as library_db
|
||||||
import backend.api.features.library.model as library_model
|
import backend.api.features.library.model as library_model
|
||||||
import backend.api.features.store.db as store_db
|
import backend.api.features.store.db as store_db
|
||||||
import backend.api.features.store.model as store_model
|
import backend.api.features.store.model as store_model
|
||||||
|
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||||
from backend.blocks import load_all_blocks
|
from backend.blocks import load_all_blocks
|
||||||
from backend.blocks._base import (
|
from backend.blocks._base import (
|
||||||
AnyBlockSchema,
|
AnyBlockSchema,
|
||||||
@@ -19,7 +21,6 @@ from backend.blocks._base import (
|
|||||||
BlockType,
|
BlockType,
|
||||||
)
|
)
|
||||||
from backend.blocks.llm import LlmModel
|
from backend.blocks.llm import LlmModel
|
||||||
from backend.data.db import query_raw_with_schema
|
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.cache import cached
|
from backend.util.cache import cached
|
||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
@@ -42,6 +43,16 @@ MAX_LIBRARY_AGENT_RESULTS = 100
|
|||||||
MAX_MARKETPLACE_AGENT_RESULTS = 100
|
MAX_MARKETPLACE_AGENT_RESULTS = 100
|
||||||
MIN_SCORE_FOR_FILTERED_RESULTS = 10.0
|
MIN_SCORE_FOR_FILTERED_RESULTS = 10.0
|
||||||
|
|
||||||
|
# Boost blocks over marketplace agents in search results
|
||||||
|
BLOCK_SCORE_BOOST = 50.0
|
||||||
|
|
||||||
|
# Block IDs to exclude from search results
|
||||||
|
EXCLUDED_BLOCK_IDS = frozenset(
|
||||||
|
{
|
||||||
|
"e189baac-8c20-45a1-94a7-55177ea42565", # AgentExecutorBlock
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
SearchResultItem = BlockInfo | library_model.LibraryAgent | store_model.StoreAgent
|
SearchResultItem = BlockInfo | library_model.LibraryAgent | store_model.StoreAgent
|
||||||
|
|
||||||
|
|
||||||
@@ -64,8 +75,8 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
|
|||||||
|
|
||||||
for block_type in load_all_blocks().values():
|
for block_type in load_all_blocks().values():
|
||||||
block: AnyBlockSchema = block_type()
|
block: AnyBlockSchema = block_type()
|
||||||
# Skip disabled blocks
|
# Skip disabled and excluded blocks
|
||||||
if block.disabled:
|
if block.disabled or block.id in EXCLUDED_BLOCK_IDS:
|
||||||
continue
|
continue
|
||||||
# Skip blocks that don't have categories (all should have at least one)
|
# Skip blocks that don't have categories (all should have at least one)
|
||||||
if not block.categories:
|
if not block.categories:
|
||||||
@@ -116,6 +127,9 @@ def get_blocks(
|
|||||||
# Skip disabled blocks
|
# Skip disabled blocks
|
||||||
if block.disabled:
|
if block.disabled:
|
||||||
continue
|
continue
|
||||||
|
# Skip excluded blocks
|
||||||
|
if block.id in EXCLUDED_BLOCK_IDS:
|
||||||
|
continue
|
||||||
# Skip blocks that don't match the category
|
# Skip blocks that don't match the category
|
||||||
if category and category not in {c.name.lower() for c in block.categories}:
|
if category and category not in {c.name.lower() for c in block.categories}:
|
||||||
continue
|
continue
|
||||||
@@ -255,14 +269,25 @@ async def _build_cached_search_results(
|
|||||||
"my_agents": 0,
|
"my_agents": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
block_results, block_total, integration_total = _collect_block_results(
|
# Use hybrid search when query is present, otherwise list all blocks
|
||||||
normalized_query=normalized_query,
|
if (include_blocks or include_integrations) and normalized_query:
|
||||||
include_blocks=include_blocks,
|
block_results, block_total, integration_total = await _hybrid_search_blocks(
|
||||||
include_integrations=include_integrations,
|
query=search_query,
|
||||||
)
|
include_blocks=include_blocks,
|
||||||
scored_items.extend(block_results)
|
include_integrations=include_integrations,
|
||||||
total_items["blocks"] = block_total
|
)
|
||||||
total_items["integrations"] = integration_total
|
scored_items.extend(block_results)
|
||||||
|
total_items["blocks"] = block_total
|
||||||
|
total_items["integrations"] = integration_total
|
||||||
|
elif include_blocks or include_integrations:
|
||||||
|
# No query - list all blocks using in-memory approach
|
||||||
|
block_results, block_total, integration_total = _collect_block_results(
|
||||||
|
include_blocks=include_blocks,
|
||||||
|
include_integrations=include_integrations,
|
||||||
|
)
|
||||||
|
scored_items.extend(block_results)
|
||||||
|
total_items["blocks"] = block_total
|
||||||
|
total_items["integrations"] = integration_total
|
||||||
|
|
||||||
if include_library_agents:
|
if include_library_agents:
|
||||||
library_response = await library_db.list_library_agents(
|
library_response = await library_db.list_library_agents(
|
||||||
@@ -307,10 +332,14 @@ async def _build_cached_search_results(
|
|||||||
|
|
||||||
def _collect_block_results(
|
def _collect_block_results(
|
||||||
*,
|
*,
|
||||||
normalized_query: str,
|
|
||||||
include_blocks: bool,
|
include_blocks: bool,
|
||||||
include_integrations: bool,
|
include_integrations: bool,
|
||||||
) -> tuple[list[_ScoredItem], int, int]:
|
) -> tuple[list[_ScoredItem], int, int]:
|
||||||
|
"""
|
||||||
|
Collect all blocks for listing (no search query).
|
||||||
|
|
||||||
|
All blocks get BLOCK_SCORE_BOOST to prioritize them over marketplace agents.
|
||||||
|
"""
|
||||||
results: list[_ScoredItem] = []
|
results: list[_ScoredItem] = []
|
||||||
block_count = 0
|
block_count = 0
|
||||||
integration_count = 0
|
integration_count = 0
|
||||||
@@ -323,6 +352,10 @@ def _collect_block_results(
|
|||||||
if block.disabled:
|
if block.disabled:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Skip excluded blocks
|
||||||
|
if block.id in EXCLUDED_BLOCK_IDS:
|
||||||
|
continue
|
||||||
|
|
||||||
block_info = block.get_info()
|
block_info = block.get_info()
|
||||||
credentials = list(block.input_schema.get_credentials_fields().values())
|
credentials = list(block.input_schema.get_credentials_fields().values())
|
||||||
is_integration = len(credentials) > 0
|
is_integration = len(credentials) > 0
|
||||||
@@ -332,10 +365,6 @@ def _collect_block_results(
|
|||||||
if not is_integration and not include_blocks:
|
if not is_integration and not include_blocks:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
score = _score_block(block, block_info, normalized_query)
|
|
||||||
if not _should_include_item(score, normalized_query):
|
|
||||||
continue
|
|
||||||
|
|
||||||
filter_type: FilterType = "integrations" if is_integration else "blocks"
|
filter_type: FilterType = "integrations" if is_integration else "blocks"
|
||||||
if is_integration:
|
if is_integration:
|
||||||
integration_count += 1
|
integration_count += 1
|
||||||
@@ -346,8 +375,122 @@ def _collect_block_results(
|
|||||||
_ScoredItem(
|
_ScoredItem(
|
||||||
item=block_info,
|
item=block_info,
|
||||||
filter_type=filter_type,
|
filter_type=filter_type,
|
||||||
score=score,
|
score=BLOCK_SCORE_BOOST,
|
||||||
sort_key=_get_item_name(block_info),
|
sort_key=block_info.name.lower(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return results, block_count, integration_count
|
||||||
|
|
||||||
|
|
||||||
|
async def _hybrid_search_blocks(
|
||||||
|
*,
|
||||||
|
query: str,
|
||||||
|
include_blocks: bool,
|
||||||
|
include_integrations: bool,
|
||||||
|
) -> tuple[list[_ScoredItem], int, int]:
|
||||||
|
"""
|
||||||
|
Search blocks using hybrid search with builder-specific filtering.
|
||||||
|
|
||||||
|
Uses unified_hybrid_search for semantic + lexical search, then applies
|
||||||
|
post-filtering for block/integration types and scoring adjustments.
|
||||||
|
|
||||||
|
Scoring:
|
||||||
|
- Base: hybrid relevance score (0-1) scaled to 0-100, plus BLOCK_SCORE_BOOST
|
||||||
|
to prioritize blocks over marketplace agents in combined results
|
||||||
|
- +30 for exact name match, +15 for prefix name match
|
||||||
|
- +20 if the block has an LlmModel field and the query matches an LLM model name
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query string
|
||||||
|
include_blocks: Whether to include regular blocks
|
||||||
|
include_integrations: Whether to include integration blocks
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (scored_items, block_count, integration_count)
|
||||||
|
"""
|
||||||
|
results: list[_ScoredItem] = []
|
||||||
|
block_count = 0
|
||||||
|
integration_count = 0
|
||||||
|
|
||||||
|
if not include_blocks and not include_integrations:
|
||||||
|
return results, block_count, integration_count
|
||||||
|
|
||||||
|
normalized_query = query.strip().lower()
|
||||||
|
|
||||||
|
# Fetch more results to account for post-filtering
|
||||||
|
search_results, _ = await unified_hybrid_search(
|
||||||
|
query=query,
|
||||||
|
content_types=[ContentType.BLOCK],
|
||||||
|
page=1,
|
||||||
|
page_size=150,
|
||||||
|
min_score=0.10,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load all blocks for getting BlockInfo
|
||||||
|
all_blocks = load_all_blocks()
|
||||||
|
|
||||||
|
for result in search_results:
|
||||||
|
block_id = result["content_id"]
|
||||||
|
|
||||||
|
# Skip excluded blocks
|
||||||
|
if block_id in EXCLUDED_BLOCK_IDS:
|
||||||
|
continue
|
||||||
|
|
||||||
|
metadata = result.get("metadata", {})
|
||||||
|
hybrid_score = result.get("relevance", 0.0)
|
||||||
|
|
||||||
|
# Get the actual block class
|
||||||
|
if block_id not in all_blocks:
|
||||||
|
continue
|
||||||
|
|
||||||
|
block_cls = all_blocks[block_id]
|
||||||
|
block: AnyBlockSchema = block_cls()
|
||||||
|
|
||||||
|
if block.disabled:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check block/integration filter using metadata
|
||||||
|
is_integration = metadata.get("is_integration", False)
|
||||||
|
|
||||||
|
if is_integration and not include_integrations:
|
||||||
|
continue
|
||||||
|
if not is_integration and not include_blocks:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get block info
|
||||||
|
block_info = block.get_info()
|
||||||
|
|
||||||
|
# Calculate final score: scale hybrid score and add builder-specific bonuses
|
||||||
|
# Hybrid scores are 0-1, builder scores were 0-200+
|
||||||
|
# Add BLOCK_SCORE_BOOST to prioritize blocks over marketplace agents
|
||||||
|
final_score = hybrid_score * 100 + BLOCK_SCORE_BOOST
|
||||||
|
|
||||||
|
# Add LLM model match bonus
|
||||||
|
has_llm_field = metadata.get("has_llm_model_field", False)
|
||||||
|
if has_llm_field and _matches_llm_model(block.input_schema, normalized_query):
|
||||||
|
final_score += 20
|
||||||
|
|
||||||
|
# Add exact/prefix match bonus for deterministic tie-breaking
|
||||||
|
name = block_info.name.lower()
|
||||||
|
if name == normalized_query:
|
||||||
|
final_score += 30
|
||||||
|
elif name.startswith(normalized_query):
|
||||||
|
final_score += 15
|
||||||
|
|
||||||
|
# Track counts
|
||||||
|
filter_type: FilterType = "integrations" if is_integration else "blocks"
|
||||||
|
if is_integration:
|
||||||
|
integration_count += 1
|
||||||
|
else:
|
||||||
|
block_count += 1
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
_ScoredItem(
|
||||||
|
item=block_info,
|
||||||
|
filter_type=filter_type,
|
||||||
|
score=final_score,
|
||||||
|
sort_key=name,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -472,6 +615,8 @@ async def _get_static_counts():
|
|||||||
block: AnyBlockSchema = block_type()
|
block: AnyBlockSchema = block_type()
|
||||||
if block.disabled:
|
if block.disabled:
|
||||||
continue
|
continue
|
||||||
|
if block.id in EXCLUDED_BLOCK_IDS:
|
||||||
|
continue
|
||||||
|
|
||||||
all_blocks += 1
|
all_blocks += 1
|
||||||
|
|
||||||
@@ -498,47 +643,25 @@ async def _get_static_counts():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _contains_type(annotation: Any, target: type) -> bool:
|
||||||
|
"""Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
|
||||||
|
if annotation is target:
|
||||||
|
return True
|
||||||
|
origin = get_origin(annotation)
|
||||||
|
if origin is None:
|
||||||
|
return False
|
||||||
|
return any(_contains_type(arg, target) for arg in get_args(annotation))
|
||||||
|
|
||||||
|
|
||||||
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
|
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
|
||||||
for field in schema_cls.model_fields.values():
|
for field in schema_cls.model_fields.values():
|
||||||
if field.annotation == LlmModel:
|
if _contains_type(field.annotation, LlmModel):
|
||||||
# Check if query matches any value in llm_models
|
# Check if query matches any value in llm_models
|
||||||
if any(query in name for name in llm_models):
|
if any(query in name for name in llm_models):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _score_block(
|
|
||||||
block: AnyBlockSchema,
|
|
||||||
block_info: BlockInfo,
|
|
||||||
normalized_query: str,
|
|
||||||
) -> float:
|
|
||||||
if not normalized_query:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
name = block_info.name.lower()
|
|
||||||
description = block_info.description.lower()
|
|
||||||
score = _score_primary_fields(name, description, normalized_query)
|
|
||||||
|
|
||||||
category_text = " ".join(
|
|
||||||
category.get("category", "").lower() for category in block_info.categories
|
|
||||||
)
|
|
||||||
score += _score_additional_field(category_text, normalized_query, 12, 6)
|
|
||||||
|
|
||||||
credentials_info = block.input_schema.get_credentials_fields_info().values()
|
|
||||||
provider_names = [
|
|
||||||
provider.value.lower()
|
|
||||||
for info in credentials_info
|
|
||||||
for provider in info.provider
|
|
||||||
]
|
|
||||||
provider_text = " ".join(provider_names)
|
|
||||||
score += _score_additional_field(provider_text, normalized_query, 15, 6)
|
|
||||||
|
|
||||||
if _matches_llm_model(block.input_schema, normalized_query):
|
|
||||||
score += 20
|
|
||||||
|
|
||||||
return score
|
|
||||||
|
|
||||||
|
|
||||||
def _score_library_agent(
|
def _score_library_agent(
|
||||||
agent: library_model.LibraryAgent,
|
agent: library_model.LibraryAgent,
|
||||||
normalized_query: str,
|
normalized_query: str,
|
||||||
@@ -645,31 +768,20 @@ def _get_all_providers() -> dict[ProviderName, Provider]:
|
|||||||
return providers
|
return providers
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl_seconds=3600)
|
@cached(ttl_seconds=3600, shared_cache=True)
|
||||||
async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
|
async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
|
||||||
suggested_blocks = []
|
"""Return the most-executed blocks from the last 14 days.
|
||||||
# Sum the number of executions for each block type
|
|
||||||
# Prisma cannot group by nested relations, so we do a raw query
|
|
||||||
# Calculate the cutoff timestamp
|
|
||||||
timestamp_threshold = datetime.now(timezone.utc) - timedelta(days=30)
|
|
||||||
|
|
||||||
results = await query_raw_with_schema(
|
Queries the mv_suggested_blocks materialized view (refreshed hourly via pg_cron)
|
||||||
"""
|
and returns the top `count` blocks sorted by execution count, excluding
|
||||||
SELECT
|
Input/Output/Agent block types and blocks in EXCLUDED_BLOCK_IDS.
|
||||||
agent_node."agentBlockId" AS block_id,
|
"""
|
||||||
COUNT(execution.id) AS execution_count
|
results = await mv_suggested_blocks.prisma().find_many()
|
||||||
FROM {schema_prefix}"AgentNodeExecution" execution
|
|
||||||
JOIN {schema_prefix}"AgentNode" agent_node ON execution."agentNodeId" = agent_node.id
|
|
||||||
WHERE execution."endedTime" >= $1::timestamp
|
|
||||||
GROUP BY agent_node."agentBlockId"
|
|
||||||
ORDER BY execution_count DESC;
|
|
||||||
""",
|
|
||||||
timestamp_threshold,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the top blocks based on execution count
|
# Get the top blocks based on execution count
|
||||||
# But ignore Input and Output blocks
|
# But ignore Input, Output, Agent, and excluded blocks
|
||||||
blocks: list[tuple[BlockInfo, int]] = []
|
blocks: list[tuple[BlockInfo, int]] = []
|
||||||
|
execution_counts = {row.block_id: row.execution_count for row in results}
|
||||||
|
|
||||||
for block_type in load_all_blocks().values():
|
for block_type in load_all_blocks().values():
|
||||||
block: AnyBlockSchema = block_type()
|
block: AnyBlockSchema = block_type()
|
||||||
@@ -679,11 +791,9 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
|
|||||||
BlockType.AGENT,
|
BlockType.AGENT,
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
# Find the execution count for this block
|
if block.id in EXCLUDED_BLOCK_IDS:
|
||||||
execution_count = next(
|
continue
|
||||||
(row["execution_count"] for row in results if row["block_id"] == block.id),
|
execution_count = execution_counts.get(block.id, 0)
|
||||||
0,
|
|
||||||
)
|
|
||||||
blocks.append((block.get_info(), execution_count))
|
blocks.append((block.get_info(), execution_count))
|
||||||
# Sort blocks by execution count
|
# Sort blocks by execution count
|
||||||
blocks.sort(key=lambda x: x[1], reverse=True)
|
blocks.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ class SearchEntry(BaseModel):
|
|||||||
|
|
||||||
# Suggestions
|
# Suggestions
|
||||||
class SuggestionsResponse(BaseModel):
|
class SuggestionsResponse(BaseModel):
|
||||||
otto_suggestions: list[str]
|
|
||||||
recent_searches: list[SearchEntry]
|
recent_searches: list[SearchEntry]
|
||||||
providers: list[ProviderName]
|
providers: list[ProviderName]
|
||||||
top_blocks: list[BlockInfo]
|
top_blocks: list[BlockInfo]
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Annotated, Sequence
|
from typing import Annotated, Sequence, cast, get_args
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||||
@@ -10,6 +10,8 @@ from backend.util.models import Pagination
|
|||||||
from . import db as builder_db
|
from . import db as builder_db
|
||||||
from . import model as builder_model
|
from . import model as builder_model
|
||||||
|
|
||||||
|
VALID_FILTER_VALUES = get_args(builder_model.FilterType)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = fastapi.APIRouter(
|
router = fastapi.APIRouter(
|
||||||
@@ -49,11 +51,6 @@ async def get_suggestions(
|
|||||||
Get all suggestions for the Blocks Menu.
|
Get all suggestions for the Blocks Menu.
|
||||||
"""
|
"""
|
||||||
return builder_model.SuggestionsResponse(
|
return builder_model.SuggestionsResponse(
|
||||||
otto_suggestions=[
|
|
||||||
"What blocks do I need to get started?",
|
|
||||||
"Help me create a list",
|
|
||||||
"Help me feed my data to Google Maps",
|
|
||||||
],
|
|
||||||
recent_searches=await builder_db.get_recent_searches(user_id),
|
recent_searches=await builder_db.get_recent_searches(user_id),
|
||||||
providers=[
|
providers=[
|
||||||
ProviderName.TWITTER,
|
ProviderName.TWITTER,
|
||||||
@@ -151,7 +148,7 @@ async def get_providers(
|
|||||||
async def search(
|
async def search(
|
||||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||||
search_query: Annotated[str | None, fastapi.Query()] = None,
|
search_query: Annotated[str | None, fastapi.Query()] = None,
|
||||||
filter: Annotated[list[builder_model.FilterType] | None, fastapi.Query()] = None,
|
filter: Annotated[str | None, fastapi.Query()] = None,
|
||||||
search_id: Annotated[str | None, fastapi.Query()] = None,
|
search_id: Annotated[str | None, fastapi.Query()] = None,
|
||||||
by_creator: Annotated[list[str] | None, fastapi.Query()] = None,
|
by_creator: Annotated[list[str] | None, fastapi.Query()] = None,
|
||||||
page: Annotated[int, fastapi.Query()] = 1,
|
page: Annotated[int, fastapi.Query()] = 1,
|
||||||
@@ -160,9 +157,20 @@ async def search(
|
|||||||
"""
|
"""
|
||||||
Search for blocks (including integrations), marketplace agents, and user library agents.
|
Search for blocks (including integrations), marketplace agents, and user library agents.
|
||||||
"""
|
"""
|
||||||
# If no filters are provided, then we will return all types
|
# Parse and validate filter parameter
|
||||||
if not filter:
|
filters: list[builder_model.FilterType]
|
||||||
filter = [
|
if filter:
|
||||||
|
filter_values = [f.strip() for f in filter.split(",")]
|
||||||
|
invalid_filters = [f for f in filter_values if f not in VALID_FILTER_VALUES]
|
||||||
|
if invalid_filters:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Invalid filter value(s): {', '.join(invalid_filters)}. "
|
||||||
|
f"Valid values are: {', '.join(VALID_FILTER_VALUES)}",
|
||||||
|
)
|
||||||
|
filters = cast(list[builder_model.FilterType], filter_values)
|
||||||
|
else:
|
||||||
|
filters = [
|
||||||
"blocks",
|
"blocks",
|
||||||
"integrations",
|
"integrations",
|
||||||
"marketplace_agents",
|
"marketplace_agents",
|
||||||
@@ -174,7 +182,7 @@ async def search(
|
|||||||
cached_results = await builder_db.get_sorted_search_results(
|
cached_results = await builder_db.get_sorted_search_results(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
search_query=search_query,
|
search_query=search_query,
|
||||||
filters=filter,
|
filters=filters,
|
||||||
by_creator=by_creator,
|
by_creator=by_creator,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -196,7 +204,7 @@ async def search(
|
|||||||
user_id,
|
user_id,
|
||||||
builder_model.SearchEntry(
|
builder_model.SearchEntry(
|
||||||
search_query=search_query,
|
search_query=search_query,
|
||||||
filter=filter,
|
filter=filters,
|
||||||
by_creator=by_creator,
|
by_creator=by_creator,
|
||||||
search_id=search_id,
|
search_id=search_id,
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -1,368 +0,0 @@
|
|||||||
"""Redis Streams consumer for operation completion messages.
|
|
||||||
|
|
||||||
This module provides a consumer (ChatCompletionConsumer) that listens for
|
|
||||||
completion notifications (OperationCompleteMessage) from external services
|
|
||||||
(like Agent Generator) and triggers the appropriate stream registry and
|
|
||||||
chat service updates via process_operation_success/process_operation_failure.
|
|
||||||
|
|
||||||
Why Redis Streams instead of RabbitMQ?
|
|
||||||
--------------------------------------
|
|
||||||
While the project typically uses RabbitMQ for async task queues (e.g., execution
|
|
||||||
queue), Redis Streams was chosen for chat completion notifications because:
|
|
||||||
|
|
||||||
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
|
|
||||||
Streams (via stream_registry) for message persistence and replay. Using Redis
|
|
||||||
Streams for completion notifications keeps all chat streaming infrastructure
|
|
||||||
in one system, simplifying operations and reducing cross-system coordination.
|
|
||||||
|
|
||||||
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
|
|
||||||
allowing consumers to replay missed messages after reconnection. This aligns
|
|
||||||
with the SSE reconnection pattern where clients can resume from last_message_id.
|
|
||||||
|
|
||||||
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
|
|
||||||
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
|
|
||||||
recovering from dead consumers - ideal for the completion callback pattern.
|
|
||||||
|
|
||||||
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
|
|
||||||
stream_registry) provides lower latency than an additional RabbitMQ hop.
|
|
||||||
|
|
||||||
5. **Atomicity with Task State**: Completion processing often needs to update
|
|
||||||
task metadata stored in Redis. Keeping both in Redis enables simpler
|
|
||||||
transactional semantics without distributed coordination.
|
|
||||||
|
|
||||||
The consumer uses Redis Streams with consumer groups for reliable message
|
|
||||||
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
|
|
||||||
stale pending messages from dead consumers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import orjson
|
|
||||||
from prisma import Prisma
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from redis.exceptions import ResponseError
|
|
||||||
|
|
||||||
from backend.data.redis_client import get_redis_async
|
|
||||||
|
|
||||||
from . import stream_registry
|
|
||||||
from .completion_handler import process_operation_failure, process_operation_success
|
|
||||||
from .config import ChatConfig
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
config = ChatConfig()
|
|
||||||
|
|
||||||
|
|
||||||
class OperationCompleteMessage(BaseModel):
|
|
||||||
"""Message format for operation completion notifications."""
|
|
||||||
|
|
||||||
operation_id: str
|
|
||||||
task_id: str
|
|
||||||
success: bool
|
|
||||||
result: dict | str | None = None
|
|
||||||
error: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionConsumer:
|
|
||||||
"""Consumer for chat operation completion messages from Redis Streams.
|
|
||||||
|
|
||||||
This consumer initializes its own Prisma client in start() to ensure
|
|
||||||
database operations work correctly within this async context.
|
|
||||||
|
|
||||||
Uses Redis consumer groups to allow multiple platform pods to consume
|
|
||||||
messages reliably with automatic redelivery on failure.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._consumer_task: asyncio.Task | None = None
|
|
||||||
self._running = False
|
|
||||||
self._prisma: Prisma | None = None
|
|
||||||
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
"""Start the completion consumer."""
|
|
||||||
if self._running:
|
|
||||||
logger.warning("Completion consumer already running")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Create consumer group if it doesn't exist
|
|
||||||
try:
|
|
||||||
redis = await get_redis_async()
|
|
||||||
await redis.xgroup_create(
|
|
||||||
config.stream_completion_name,
|
|
||||||
config.stream_consumer_group,
|
|
||||||
id="0",
|
|
||||||
mkstream=True,
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Created consumer group '{config.stream_consumer_group}' "
|
|
||||||
f"on stream '{config.stream_completion_name}'"
|
|
||||||
)
|
|
||||||
except ResponseError as e:
|
|
||||||
if "BUSYGROUP" in str(e):
|
|
||||||
logger.debug(
|
|
||||||
f"Consumer group '{config.stream_consumer_group}' already exists"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
self._running = True
|
|
||||||
self._consumer_task = asyncio.create_task(self._consume_messages())
|
|
||||||
logger.info(
|
|
||||||
f"Chat completion consumer started (consumer: {self._consumer_name})"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _ensure_prisma(self) -> Prisma:
|
|
||||||
"""Lazily initialize Prisma client on first use."""
|
|
||||||
if self._prisma is None:
|
|
||||||
database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
|
|
||||||
self._prisma = Prisma(datasource={"url": database_url})
|
|
||||||
await self._prisma.connect()
|
|
||||||
logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
|
|
||||||
return self._prisma
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
"""Stop the completion consumer."""
|
|
||||||
self._running = False
|
|
||||||
|
|
||||||
if self._consumer_task:
|
|
||||||
self._consumer_task.cancel()
|
|
||||||
try:
|
|
||||||
await self._consumer_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
self._consumer_task = None
|
|
||||||
|
|
||||||
if self._prisma:
|
|
||||||
await self._prisma.disconnect()
|
|
||||||
self._prisma = None
|
|
||||||
logger.info("[COMPLETION] Consumer Prisma client disconnected")
|
|
||||||
|
|
||||||
logger.info("Chat completion consumer stopped")
|
|
||||||
|
|
||||||
async def _consume_messages(self) -> None:
|
|
||||||
"""Main message consumption loop with retry logic."""
|
|
||||||
max_retries = 10
|
|
||||||
retry_delay = 5 # seconds
|
|
||||||
retry_count = 0
|
|
||||||
block_timeout = 5000 # milliseconds
|
|
||||||
|
|
||||||
while self._running and retry_count < max_retries:
|
|
||||||
try:
|
|
||||||
redis = await get_redis_async()
|
|
||||||
|
|
||||||
# Reset retry count on successful connection
|
|
||||||
retry_count = 0
|
|
||||||
|
|
||||||
while self._running:
|
|
||||||
# First, claim any stale pending messages from dead consumers
|
|
||||||
# Redis does NOT auto-redeliver pending messages; we must explicitly
|
|
||||||
# claim them using XAUTOCLAIM
|
|
||||||
try:
|
|
||||||
claimed_result = await redis.xautoclaim(
|
|
||||||
name=config.stream_completion_name,
|
|
||||||
groupname=config.stream_consumer_group,
|
|
||||||
consumername=self._consumer_name,
|
|
||||||
min_idle_time=config.stream_claim_min_idle_ms,
|
|
||||||
start_id="0-0",
|
|
||||||
count=10,
|
|
||||||
)
|
|
||||||
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
|
|
||||||
if claimed_result and len(claimed_result) >= 2:
|
|
||||||
claimed_entries = claimed_result[1]
|
|
||||||
if claimed_entries:
|
|
||||||
logger.info(
|
|
||||||
f"Claimed {len(claimed_entries)} stale pending messages"
|
|
||||||
)
|
|
||||||
for entry_id, data in claimed_entries:
|
|
||||||
if not self._running:
|
|
||||||
return
|
|
||||||
await self._process_entry(redis, entry_id, data)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
|
|
||||||
|
|
||||||
# Read new messages from the stream
|
|
||||||
messages = await redis.xreadgroup(
|
|
||||||
groupname=config.stream_consumer_group,
|
|
||||||
consumername=self._consumer_name,
|
|
||||||
streams={config.stream_completion_name: ">"},
|
|
||||||
block=block_timeout,
|
|
||||||
count=10,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not messages:
|
|
||||||
continue
|
|
||||||
|
|
||||||
for stream_name, entries in messages:
|
|
||||||
for entry_id, data in entries:
|
|
||||||
if not self._running:
|
|
||||||
return
|
|
||||||
await self._process_entry(redis, entry_id, data)
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.info("Consumer cancelled")
|
|
||||||
return
|
|
||||||
except Exception as e:
|
|
||||||
retry_count += 1
|
|
||||||
logger.error(
|
|
||||||
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
if self._running and retry_count < max_retries:
|
|
||||||
await asyncio.sleep(retry_delay)
|
|
||||||
else:
|
|
||||||
logger.error("Max retries reached, stopping consumer")
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _process_entry(
|
|
||||||
self, redis: Any, entry_id: str, data: dict[str, Any]
|
|
||||||
) -> None:
|
|
||||||
"""Process a single stream entry and acknowledge it on success.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
redis: Redis client connection
|
|
||||||
entry_id: The stream entry ID
|
|
||||||
data: The entry data dict
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Handle the message
|
|
||||||
message_data = data.get("data")
|
|
||||||
if message_data:
|
|
||||||
await self._handle_message(
|
|
||||||
message_data.encode()
|
|
||||||
if isinstance(message_data, str)
|
|
||||||
else message_data
|
|
||||||
)
|
|
||||||
|
|
||||||
# Acknowledge the message after successful processing
|
|
||||||
await redis.xack(
|
|
||||||
config.stream_completion_name,
|
|
||||||
config.stream_consumer_group,
|
|
||||||
entry_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error processing completion message {entry_id}: {e}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
# Message remains in pending state and will be claimed by
|
|
||||||
# XAUTOCLAIM after min_idle_time expires
|
|
||||||
|
|
||||||
async def _handle_message(self, body: bytes) -> None:
|
|
||||||
"""Handle a completion message using our own Prisma client."""
|
|
||||||
try:
|
|
||||||
data = orjson.loads(body)
|
|
||||||
message = OperationCompleteMessage(**data)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to parse completion message: {e}")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[COMPLETION] Received completion for operation {message.operation_id} "
|
|
||||||
f"(task_id={message.task_id}, success={message.success})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Find task in registry
|
|
||||||
task = await stream_registry.find_task_by_operation_id(message.operation_id)
|
|
||||||
if task is None:
|
|
||||||
task = await stream_registry.get_task(message.task_id)
|
|
||||||
|
|
||||||
if task is None:
|
|
||||||
logger.warning(
|
|
||||||
f"[COMPLETION] Task not found for operation {message.operation_id} "
|
|
||||||
f"(task_id={message.task_id})"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[COMPLETION] Found task: task_id={task.task_id}, "
|
|
||||||
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Guard against empty task fields
|
|
||||||
if not task.task_id or not task.session_id or not task.tool_call_id:
|
|
||||||
logger.error(
|
|
||||||
f"[COMPLETION] Task has empty critical fields! "
|
|
||||||
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
|
|
||||||
f"tool_call_id={task.tool_call_id!r}"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if message.success:
|
|
||||||
await self._handle_success(task, message)
|
|
||||||
else:
|
|
||||||
await self._handle_failure(task, message)
|
|
||||||
|
|
||||||
async def _handle_success(
|
|
||||||
self,
|
|
||||||
task: stream_registry.ActiveTask,
|
|
||||||
message: OperationCompleteMessage,
|
|
||||||
) -> None:
|
|
||||||
"""Handle successful operation completion."""
|
|
||||||
prisma = await self._ensure_prisma()
|
|
||||||
await process_operation_success(task, message.result, prisma)
|
|
||||||
|
|
||||||
async def _handle_failure(
|
|
||||||
self,
|
|
||||||
task: stream_registry.ActiveTask,
|
|
||||||
message: OperationCompleteMessage,
|
|
||||||
) -> None:
|
|
||||||
"""Handle failed operation completion."""
|
|
||||||
prisma = await self._ensure_prisma()
|
|
||||||
await process_operation_failure(task, message.error, prisma)
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level consumer instance
|
|
||||||
_consumer: ChatCompletionConsumer | None = None
|
|
||||||
|
|
||||||
|
|
||||||
async def start_completion_consumer() -> None:
|
|
||||||
"""Start the global completion consumer."""
|
|
||||||
global _consumer
|
|
||||||
if _consumer is None:
|
|
||||||
_consumer = ChatCompletionConsumer()
|
|
||||||
await _consumer.start()
|
|
||||||
|
|
||||||
|
|
||||||
async def stop_completion_consumer() -> None:
|
|
||||||
"""Stop the global completion consumer."""
|
|
||||||
global _consumer
|
|
||||||
if _consumer:
|
|
||||||
await _consumer.stop()
|
|
||||||
_consumer = None
|
|
||||||
|
|
||||||
|
|
||||||
async def publish_operation_complete(
|
|
||||||
operation_id: str,
|
|
||||||
task_id: str,
|
|
||||||
success: bool,
|
|
||||||
result: dict | str | None = None,
|
|
||||||
error: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Publish an operation completion message to Redis Streams.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
operation_id: The operation ID that completed.
|
|
||||||
task_id: The task ID associated with the operation.
|
|
||||||
success: Whether the operation succeeded.
|
|
||||||
result: The result data (for success).
|
|
||||||
error: The error message (for failure).
|
|
||||||
"""
|
|
||||||
message = OperationCompleteMessage(
|
|
||||||
operation_id=operation_id,
|
|
||||||
task_id=task_id,
|
|
||||||
success=success,
|
|
||||||
result=result,
|
|
||||||
error=error,
|
|
||||||
)
|
|
||||||
|
|
||||||
redis = await get_redis_async()
|
|
||||||
await redis.xadd(
|
|
||||||
config.stream_completion_name,
|
|
||||||
{"data": message.model_dump_json()},
|
|
||||||
maxlen=config.stream_max_length,
|
|
||||||
)
|
|
||||||
logger.info(f"Published completion for operation {operation_id}")
|
|
||||||
@@ -1,344 +0,0 @@
|
|||||||
"""Shared completion handling for operation success and failure.
|
|
||||||
|
|
||||||
This module provides common logic for handling operation completion from both:
|
|
||||||
- The Redis Streams consumer (completion_consumer.py)
|
|
||||||
- The HTTP webhook endpoint (routes.py)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import orjson
|
|
||||||
from prisma import Prisma
|
|
||||||
|
|
||||||
from . import service as chat_service
|
|
||||||
from . import stream_registry
|
|
||||||
from .response_model import StreamError, StreamToolOutputAvailable
|
|
||||||
from .tools.models import ErrorResponse
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Tools that produce agent_json that needs to be saved to library
|
|
||||||
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
|
|
||||||
|
|
||||||
# Keys that should be stripped from agent_json when returning in error responses
|
|
||||||
SENSITIVE_KEYS = frozenset(
|
|
||||||
{
|
|
||||||
"api_key",
|
|
||||||
"apikey",
|
|
||||||
"api_secret",
|
|
||||||
"password",
|
|
||||||
"secret",
|
|
||||||
"credentials",
|
|
||||||
"credential",
|
|
||||||
"token",
|
|
||||||
"access_token",
|
|
||||||
"refresh_token",
|
|
||||||
"private_key",
|
|
||||||
"privatekey",
|
|
||||||
"auth",
|
|
||||||
"authorization",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_agent_json(obj: Any) -> Any:
|
|
||||||
"""Recursively sanitize agent_json by removing sensitive keys.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
obj: The object to sanitize (dict, list, or primitive)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Sanitized copy with sensitive keys removed/redacted
|
|
||||||
"""
|
|
||||||
if isinstance(obj, dict):
|
|
||||||
return {
|
|
||||||
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
|
|
||||||
for k, v in obj.items()
|
|
||||||
}
|
|
||||||
elif isinstance(obj, list):
|
|
||||||
return [_sanitize_agent_json(item) for item in obj]
|
|
||||||
else:
|
|
||||||
return obj
|
|
||||||
|
|
||||||
|
|
||||||
class ToolMessageUpdateError(Exception):
|
|
||||||
"""Raised when updating a tool message in the database fails."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
async def _update_tool_message(
|
|
||||||
session_id: str,
|
|
||||||
tool_call_id: str,
|
|
||||||
content: str,
|
|
||||||
prisma_client: Prisma | None,
|
|
||||||
) -> None:
|
|
||||||
"""Update tool message in database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session_id: The session ID
|
|
||||||
tool_call_id: The tool call ID to update
|
|
||||||
content: The new content for the message
|
|
||||||
prisma_client: Optional Prisma client. If None, uses chat_service.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ToolMessageUpdateError: If the database update fails. The caller should
|
|
||||||
handle this to avoid marking the task as completed with inconsistent state.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if prisma_client:
|
|
||||||
# Use provided Prisma client (for consumer with its own connection)
|
|
||||||
updated_count = await prisma_client.chatmessage.update_many(
|
|
||||||
where={
|
|
||||||
"sessionId": session_id,
|
|
||||||
"toolCallId": tool_call_id,
|
|
||||||
},
|
|
||||||
data={"content": content},
|
|
||||||
)
|
|
||||||
# Check if any rows were updated - 0 means message not found
|
|
||||||
if updated_count == 0:
|
|
||||||
raise ToolMessageUpdateError(
|
|
||||||
f"No message found with tool_call_id={tool_call_id} in session {session_id}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Use service function (for webhook endpoint)
|
|
||||||
await chat_service._update_pending_operation(
|
|
||||||
session_id=session_id,
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
result=content,
|
|
||||||
)
|
|
||||||
except ToolMessageUpdateError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
|
|
||||||
raise ToolMessageUpdateError(
|
|
||||||
f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
|
|
||||||
"""Serialize result to JSON string with sensible defaults.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
result: The result to serialize. Can be a dict, list, string,
|
|
||||||
number, boolean, or None.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
JSON string representation of the result. Returns '{"status": "completed"}'
|
|
||||||
only when result is explicitly None.
|
|
||||||
"""
|
|
||||||
if isinstance(result, str):
|
|
||||||
return result
|
|
||||||
if result is None:
|
|
||||||
return '{"status": "completed"}'
|
|
||||||
return orjson.dumps(result).decode("utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
async def _save_agent_from_result(
|
|
||||||
result: dict[str, Any],
|
|
||||||
user_id: str | None,
|
|
||||||
tool_name: str,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Save agent to library if result contains agent_json.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
result: The result dict that may contain agent_json
|
|
||||||
user_id: The user ID to save the agent for
|
|
||||||
tool_name: The tool name (create_agent or edit_agent)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated result dict with saved agent details, or original result if no agent_json
|
|
||||||
"""
|
|
||||||
if not user_id:
|
|
||||||
logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
|
|
||||||
return result
|
|
||||||
|
|
||||||
agent_json = result.get("agent_json")
|
|
||||||
if not agent_json:
|
|
||||||
logger.warning(
|
|
||||||
f"[COMPLETION] {tool_name} completed but no agent_json in result"
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
try:
|
|
||||||
from .tools.agent_generator import save_agent_to_library
|
|
||||||
|
|
||||||
is_update = tool_name == "edit_agent"
|
|
||||||
created_graph, library_agent = await save_agent_to_library(
|
|
||||||
agent_json, user_id, is_update=is_update
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
|
|
||||||
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return a response similar to AgentSavedResponse
|
|
||||||
return {
|
|
||||||
"type": "agent_saved",
|
|
||||||
"message": f"Agent '{created_graph.name}' has been saved to your library!",
|
|
||||||
"agent_id": created_graph.id,
|
|
||||||
"agent_name": created_graph.name,
|
|
||||||
"library_agent_id": library_agent.id,
|
|
||||||
"library_agent_link": f"/library/agents/{library_agent.id}",
|
|
||||||
"agent_page_link": f"/build?flowID={created_graph.id}",
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"[COMPLETION] Failed to save agent to library: {e}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
# Return error but don't fail the whole operation
|
|
||||||
# Sanitize agent_json to remove sensitive keys before returning
|
|
||||||
return {
|
|
||||||
"type": "error",
|
|
||||||
"message": f"Agent was generated but failed to save: {str(e)}",
|
|
||||||
"error": str(e),
|
|
||||||
"agent_json": _sanitize_agent_json(agent_json),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def process_operation_success(
|
|
||||||
task: stream_registry.ActiveTask,
|
|
||||||
result: dict | str | None,
|
|
||||||
prisma_client: Prisma | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Handle successful operation completion.
|
|
||||||
|
|
||||||
Publishes the result to the stream registry, updates the database,
|
|
||||||
generates LLM continuation, and marks the task as completed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task: The active task that completed
|
|
||||||
result: The result data from the operation
|
|
||||||
prisma_client: Optional Prisma client for database operations.
|
|
||||||
If None, uses chat_service._update_pending_operation instead.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ToolMessageUpdateError: If the database update fails. The task will be
|
|
||||||
marked as failed instead of completed to avoid inconsistent state.
|
|
||||||
"""
|
|
||||||
# For agent generation tools, save the agent to library
|
|
||||||
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
|
|
||||||
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
|
|
||||||
|
|
||||||
# Serialize result for output (only substitute default when result is exactly None)
|
|
||||||
result_output = result if result is not None else {"status": "completed"}
|
|
||||||
output_str = (
|
|
||||||
result_output
|
|
||||||
if isinstance(result_output, str)
|
|
||||||
else orjson.dumps(result_output).decode("utf-8")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Publish result to stream registry
|
|
||||||
await stream_registry.publish_chunk(
|
|
||||||
task.task_id,
|
|
||||||
StreamToolOutputAvailable(
|
|
||||||
toolCallId=task.tool_call_id,
|
|
||||||
toolName=task.tool_name,
|
|
||||||
output=output_str,
|
|
||||||
success=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update pending operation in database
|
|
||||||
# If this fails, we must not continue to mark the task as completed
|
|
||||||
result_str = serialize_result(result)
|
|
||||||
try:
|
|
||||||
await _update_tool_message(
|
|
||||||
session_id=task.session_id,
|
|
||||||
tool_call_id=task.tool_call_id,
|
|
||||||
content=result_str,
|
|
||||||
prisma_client=prisma_client,
|
|
||||||
)
|
|
||||||
except ToolMessageUpdateError:
|
|
||||||
# DB update failed - mark task as failed to avoid inconsistent state
|
|
||||||
logger.error(
|
|
||||||
f"[COMPLETION] DB update failed for task {task.task_id}, "
|
|
||||||
"marking as failed instead of completed"
|
|
||||||
)
|
|
||||||
await stream_registry.publish_chunk(
|
|
||||||
task.task_id,
|
|
||||||
StreamError(errorText="Failed to save operation result to database"),
|
|
||||||
)
|
|
||||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Generate LLM continuation with streaming
|
|
||||||
try:
|
|
||||||
await chat_service._generate_llm_continuation_with_streaming(
|
|
||||||
session_id=task.session_id,
|
|
||||||
user_id=task.user_id,
|
|
||||||
task_id=task.task_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"[COMPLETION] Failed to generate LLM continuation: {e}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mark task as completed and release Redis lock
|
|
||||||
await stream_registry.mark_task_completed(task.task_id, status="completed")
|
|
||||||
try:
|
|
||||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def process_operation_failure(
|
|
||||||
task: stream_registry.ActiveTask,
|
|
||||||
error: str | None,
|
|
||||||
prisma_client: Prisma | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Handle failed operation completion.
|
|
||||||
|
|
||||||
Publishes the error to the stream registry, updates the database with
|
|
||||||
the error response, and marks the task as failed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task: The active task that failed
|
|
||||||
error: The error message from the operation
|
|
||||||
prisma_client: Optional Prisma client for database operations.
|
|
||||||
If None, uses chat_service._update_pending_operation instead.
|
|
||||||
"""
|
|
||||||
error_msg = error or "Operation failed"
|
|
||||||
|
|
||||||
# Publish error to stream registry
|
|
||||||
await stream_registry.publish_chunk(
|
|
||||||
task.task_id,
|
|
||||||
StreamError(errorText=error_msg),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update pending operation with error
|
|
||||||
# If this fails, we still continue to mark the task as failed
|
|
||||||
error_response = ErrorResponse(
|
|
||||||
message=error_msg,
|
|
||||||
error=error,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
await _update_tool_message(
|
|
||||||
session_id=task.session_id,
|
|
||||||
tool_call_id=task.tool_call_id,
|
|
||||||
content=error_response.model_dump_json(),
|
|
||||||
prisma_client=prisma_client,
|
|
||||||
)
|
|
||||||
except ToolMessageUpdateError:
|
|
||||||
# DB update failed - log but continue with cleanup
|
|
||||||
logger.error(
|
|
||||||
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
|
|
||||||
"continuing with cleanup"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mark task as failed and release Redis lock
|
|
||||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
|
||||||
try:
|
|
||||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
|
||||||
|
|
||||||
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")
|
|
||||||
@@ -1,152 +0,0 @@
|
|||||||
"""Configuration management for chat system."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
from pydantic import Field, field_validator
|
|
||||||
from pydantic_settings import BaseSettings
|
|
||||||
|
|
||||||
|
|
||||||
class ChatConfig(BaseSettings):
|
|
||||||
"""Configuration for the chat system."""
|
|
||||||
|
|
||||||
# OpenAI API Configuration
|
|
||||||
model: str = Field(
|
|
||||||
default="anthropic/claude-opus-4.6", description="Default model to use"
|
|
||||||
)
|
|
||||||
title_model: str = Field(
|
|
||||||
default="openai/gpt-4o-mini",
|
|
||||||
description="Model to use for generating session titles (should be fast/cheap)",
|
|
||||||
)
|
|
||||||
api_key: str | None = Field(default=None, description="OpenAI API key")
|
|
||||||
base_url: str | None = Field(
|
|
||||||
default="https://openrouter.ai/api/v1",
|
|
||||||
description="Base URL for API (e.g., for OpenRouter)",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Session TTL Configuration - 12 hours
|
|
||||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
|
||||||
|
|
||||||
# Streaming Configuration
|
|
||||||
max_context_messages: int = Field(
|
|
||||||
default=50, ge=1, le=200, description="Maximum context messages"
|
|
||||||
)
|
|
||||||
|
|
||||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
|
||||||
max_retries: int = Field(default=3, description="Maximum number of retries")
|
|
||||||
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
|
||||||
max_agent_schedules: int = Field(
|
|
||||||
default=30, description="Maximum number of agent schedules"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Long-running operation configuration
|
|
||||||
long_running_operation_ttl: int = Field(
|
|
||||||
default=600,
|
|
||||||
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stream registry configuration for SSE reconnection
|
|
||||||
stream_ttl: int = Field(
|
|
||||||
default=3600,
|
|
||||||
description="TTL in seconds for stream data in Redis (1 hour)",
|
|
||||||
)
|
|
||||||
stream_max_length: int = Field(
|
|
||||||
default=10000,
|
|
||||||
description="Maximum number of messages to store per stream",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Redis Streams configuration for completion consumer
|
|
||||||
stream_completion_name: str = Field(
|
|
||||||
default="chat:completions",
|
|
||||||
description="Redis Stream name for operation completions",
|
|
||||||
)
|
|
||||||
stream_consumer_group: str = Field(
|
|
||||||
default="chat_consumers",
|
|
||||||
description="Consumer group name for completion stream",
|
|
||||||
)
|
|
||||||
stream_claim_min_idle_ms: int = Field(
|
|
||||||
default=60000,
|
|
||||||
description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Redis key prefixes for stream registry
|
|
||||||
task_meta_prefix: str = Field(
|
|
||||||
default="chat:task:meta:",
|
|
||||||
description="Prefix for task metadata hash keys",
|
|
||||||
)
|
|
||||||
task_stream_prefix: str = Field(
|
|
||||||
default="chat:stream:",
|
|
||||||
description="Prefix for task message stream keys",
|
|
||||||
)
|
|
||||||
task_op_prefix: str = Field(
|
|
||||||
default="chat:task:op:",
|
|
||||||
description="Prefix for operation ID to task ID mapping keys",
|
|
||||||
)
|
|
||||||
internal_api_key: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Langfuse Prompt Management Configuration
|
|
||||||
# Note: Langfuse credentials are in Settings().secrets (settings.py)
|
|
||||||
langfuse_prompt_name: str = Field(
|
|
||||||
default="CoPilot Prompt",
|
|
||||||
description="Name of the prompt in Langfuse to fetch",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extended thinking configuration for Claude models
|
|
||||||
thinking_enabled: bool = Field(
|
|
||||||
default=True,
|
|
||||||
description="Enable adaptive thinking for Claude models via OpenRouter",
|
|
||||||
)
|
|
||||||
|
|
||||||
@field_validator("api_key", mode="before")
|
|
||||||
@classmethod
|
|
||||||
def get_api_key(cls, v):
|
|
||||||
"""Get API key from environment if not provided."""
|
|
||||||
if v is None:
|
|
||||||
# Try to get from environment variables
|
|
||||||
# First check for CHAT_API_KEY (Pydantic prefix)
|
|
||||||
v = os.getenv("CHAT_API_KEY")
|
|
||||||
if not v:
|
|
||||||
# Fall back to OPEN_ROUTER_API_KEY
|
|
||||||
v = os.getenv("OPEN_ROUTER_API_KEY")
|
|
||||||
if not v:
|
|
||||||
# Fall back to OPENAI_API_KEY
|
|
||||||
v = os.getenv("OPENAI_API_KEY")
|
|
||||||
return v
|
|
||||||
|
|
||||||
@field_validator("base_url", mode="before")
|
|
||||||
@classmethod
|
|
||||||
def get_base_url(cls, v):
|
|
||||||
"""Get base URL from environment if not provided."""
|
|
||||||
if v is None:
|
|
||||||
# Check for OpenRouter or custom base URL
|
|
||||||
v = os.getenv("CHAT_BASE_URL")
|
|
||||||
if not v:
|
|
||||||
v = os.getenv("OPENROUTER_BASE_URL")
|
|
||||||
if not v:
|
|
||||||
v = os.getenv("OPENAI_BASE_URL")
|
|
||||||
if not v:
|
|
||||||
v = "https://openrouter.ai/api/v1"
|
|
||||||
return v
|
|
||||||
|
|
||||||
@field_validator("internal_api_key", mode="before")
|
|
||||||
@classmethod
|
|
||||||
def get_internal_api_key(cls, v):
|
|
||||||
"""Get internal API key from environment if not provided."""
|
|
||||||
if v is None:
|
|
||||||
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
|
||||||
return v
|
|
||||||
|
|
||||||
# Prompt paths for different contexts
|
|
||||||
PROMPT_PATHS: dict[str, str] = {
|
|
||||||
"default": "prompts/chat_system.md",
|
|
||||||
"onboarding": "prompts/onboarding_system.md",
|
|
||||||
}
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""Pydantic config."""
|
|
||||||
|
|
||||||
env_file = ".env"
|
|
||||||
env_file_encoding = "utf-8"
|
|
||||||
extra = "ignore" # Ignore extra environment variables
|
|
||||||
@@ -1,288 +0,0 @@
|
|||||||
"""Database operations for chat sessions."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from datetime import UTC, datetime
|
|
||||||
from typing import Any, cast
|
|
||||||
|
|
||||||
from prisma.models import ChatMessage as PrismaChatMessage
|
|
||||||
from prisma.models import ChatSession as PrismaChatSession
|
|
||||||
from prisma.types import (
|
|
||||||
ChatMessageCreateInput,
|
|
||||||
ChatSessionCreateInput,
|
|
||||||
ChatSessionUpdateInput,
|
|
||||||
ChatSessionWhereInput,
|
|
||||||
)
|
|
||||||
|
|
||||||
from backend.data.db import transaction
|
|
||||||
from backend.util.json import SafeJson
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
|
|
||||||
"""Get a chat session by ID from the database."""
|
|
||||||
session = await PrismaChatSession.prisma().find_unique(
|
|
||||||
where={"id": session_id},
|
|
||||||
include={"Messages": True},
|
|
||||||
)
|
|
||||||
if session and session.Messages:
|
|
||||||
# Sort messages by sequence in Python - Prisma Python client doesn't support
|
|
||||||
# order_by in include clauses (unlike Prisma JS), so we sort after fetching
|
|
||||||
session.Messages.sort(key=lambda m: m.sequence)
|
|
||||||
return session
|
|
||||||
|
|
||||||
|
|
||||||
async def create_chat_session(
|
|
||||||
session_id: str,
|
|
||||||
user_id: str,
|
|
||||||
) -> PrismaChatSession:
|
|
||||||
"""Create a new chat session in the database."""
|
|
||||||
data = ChatSessionCreateInput(
|
|
||||||
id=session_id,
|
|
||||||
userId=user_id,
|
|
||||||
credentials=SafeJson({}),
|
|
||||||
successfulAgentRuns=SafeJson({}),
|
|
||||||
successfulAgentSchedules=SafeJson({}),
|
|
||||||
)
|
|
||||||
return await PrismaChatSession.prisma().create(data=data)
|
|
||||||
|
|
||||||
|
|
||||||
async def update_chat_session(
|
|
||||||
session_id: str,
|
|
||||||
credentials: dict[str, Any] | None = None,
|
|
||||||
successful_agent_runs: dict[str, Any] | None = None,
|
|
||||||
successful_agent_schedules: dict[str, Any] | None = None,
|
|
||||||
total_prompt_tokens: int | None = None,
|
|
||||||
total_completion_tokens: int | None = None,
|
|
||||||
title: str | None = None,
|
|
||||||
) -> PrismaChatSession | None:
|
|
||||||
"""Update a chat session's metadata."""
|
|
||||||
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
|
|
||||||
|
|
||||||
if credentials is not None:
|
|
||||||
data["credentials"] = SafeJson(credentials)
|
|
||||||
if successful_agent_runs is not None:
|
|
||||||
data["successfulAgentRuns"] = SafeJson(successful_agent_runs)
|
|
||||||
if successful_agent_schedules is not None:
|
|
||||||
data["successfulAgentSchedules"] = SafeJson(successful_agent_schedules)
|
|
||||||
if total_prompt_tokens is not None:
|
|
||||||
data["totalPromptTokens"] = total_prompt_tokens
|
|
||||||
if total_completion_tokens is not None:
|
|
||||||
data["totalCompletionTokens"] = total_completion_tokens
|
|
||||||
if title is not None:
|
|
||||||
data["title"] = title
|
|
||||||
|
|
||||||
session = await PrismaChatSession.prisma().update(
|
|
||||||
where={"id": session_id},
|
|
||||||
data=data,
|
|
||||||
include={"Messages": True},
|
|
||||||
)
|
|
||||||
if session and session.Messages:
|
|
||||||
# Sort in Python - Prisma Python doesn't support order_by in include clauses
|
|
||||||
session.Messages.sort(key=lambda m: m.sequence)
|
|
||||||
return session
|
|
||||||
|
|
||||||
|
|
||||||
async def add_chat_message(
|
|
||||||
session_id: str,
|
|
||||||
role: str,
|
|
||||||
sequence: int,
|
|
||||||
content: str | None = None,
|
|
||||||
name: str | None = None,
|
|
||||||
tool_call_id: str | None = None,
|
|
||||||
refusal: str | None = None,
|
|
||||||
tool_calls: list[dict[str, Any]] | None = None,
|
|
||||||
function_call: dict[str, Any] | None = None,
|
|
||||||
) -> PrismaChatMessage:
|
|
||||||
"""Add a message to a chat session."""
|
|
||||||
# Build input dict dynamically rather than using ChatMessageCreateInput directly
|
|
||||||
# because Prisma's TypedDict validation rejects optional fields set to None.
|
|
||||||
# We only include fields that have values, then cast at the end.
|
|
||||||
data: dict[str, Any] = {
|
|
||||||
"Session": {"connect": {"id": session_id}},
|
|
||||||
"role": role,
|
|
||||||
"sequence": sequence,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add optional string fields
|
|
||||||
if content is not None:
|
|
||||||
data["content"] = content
|
|
||||||
if name is not None:
|
|
||||||
data["name"] = name
|
|
||||||
if tool_call_id is not None:
|
|
||||||
data["toolCallId"] = tool_call_id
|
|
||||||
if refusal is not None:
|
|
||||||
data["refusal"] = refusal
|
|
||||||
|
|
||||||
# Add optional JSON fields only when they have values
|
|
||||||
if tool_calls is not None:
|
|
||||||
data["toolCalls"] = SafeJson(tool_calls)
|
|
||||||
if function_call is not None:
|
|
||||||
data["functionCall"] = SafeJson(function_call)
|
|
||||||
|
|
||||||
# Run message create and session timestamp update in parallel for lower latency
|
|
||||||
_, message = await asyncio.gather(
|
|
||||||
PrismaChatSession.prisma().update(
|
|
||||||
where={"id": session_id},
|
|
||||||
data={"updatedAt": datetime.now(UTC)},
|
|
||||||
),
|
|
||||||
PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
|
|
||||||
)
|
|
||||||
return message
|
|
||||||
|
|
||||||
|
|
||||||
async def add_chat_messages_batch(
|
|
||||||
session_id: str,
|
|
||||||
messages: list[dict[str, Any]],
|
|
||||||
start_sequence: int,
|
|
||||||
) -> list[PrismaChatMessage]:
|
|
||||||
"""Add multiple messages to a chat session in a batch.
|
|
||||||
|
|
||||||
Uses a transaction for atomicity - if any message creation fails,
|
|
||||||
the entire batch is rolled back.
|
|
||||||
"""
|
|
||||||
if not messages:
|
|
||||||
return []
|
|
||||||
|
|
||||||
created_messages = []
|
|
||||||
|
|
||||||
async with transaction() as tx:
|
|
||||||
for i, msg in enumerate(messages):
|
|
||||||
# Build input dict dynamically rather than using ChatMessageCreateInput
|
|
||||||
# directly because Prisma's TypedDict validation rejects optional fields
|
|
||||||
# set to None. We only include fields that have values, then cast.
|
|
||||||
data: dict[str, Any] = {
|
|
||||||
"Session": {"connect": {"id": session_id}},
|
|
||||||
"role": msg["role"],
|
|
||||||
"sequence": start_sequence + i,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add optional string fields
|
|
||||||
if msg.get("content") is not None:
|
|
||||||
data["content"] = msg["content"]
|
|
||||||
if msg.get("name") is not None:
|
|
||||||
data["name"] = msg["name"]
|
|
||||||
if msg.get("tool_call_id") is not None:
|
|
||||||
data["toolCallId"] = msg["tool_call_id"]
|
|
||||||
if msg.get("refusal") is not None:
|
|
||||||
data["refusal"] = msg["refusal"]
|
|
||||||
|
|
||||||
# Add optional JSON fields only when they have values
|
|
||||||
if msg.get("tool_calls") is not None:
|
|
||||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
|
||||||
if msg.get("function_call") is not None:
|
|
||||||
data["functionCall"] = SafeJson(msg["function_call"])
|
|
||||||
|
|
||||||
created = await PrismaChatMessage.prisma(tx).create(
|
|
||||||
data=cast(ChatMessageCreateInput, data)
|
|
||||||
)
|
|
||||||
created_messages.append(created)
|
|
||||||
|
|
||||||
# Update session's updatedAt timestamp within the same transaction.
|
|
||||||
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
|
|
||||||
# separately via update_chat_session() after streaming completes.
|
|
||||||
await PrismaChatSession.prisma(tx).update(
|
|
||||||
where={"id": session_id},
|
|
||||||
data={"updatedAt": datetime.now(UTC)},
|
|
||||||
)
|
|
||||||
|
|
||||||
return created_messages
|
|
||||||
|
|
||||||
|
|
||||||
async def get_user_chat_sessions(
|
|
||||||
user_id: str,
|
|
||||||
limit: int = 50,
|
|
||||||
offset: int = 0,
|
|
||||||
) -> list[PrismaChatSession]:
|
|
||||||
"""Get chat sessions for a user, ordered by most recent."""
|
|
||||||
return await PrismaChatSession.prisma().find_many(
|
|
||||||
where={"userId": user_id},
|
|
||||||
order={"updatedAt": "desc"},
|
|
||||||
take=limit,
|
|
||||||
skip=offset,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_user_session_count(user_id: str) -> int:
|
|
||||||
"""Get the total number of chat sessions for a user."""
|
|
||||||
return await PrismaChatSession.prisma().count(where={"userId": user_id})
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
|
|
||||||
"""Delete a chat session and all its messages.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session_id: The session ID to delete.
|
|
||||||
user_id: If provided, validates that the session belongs to this user
|
|
||||||
before deletion. This prevents unauthorized deletion of other
|
|
||||||
users' sessions.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if deleted successfully, False otherwise.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Build typed where clause with optional user_id validation
|
|
||||||
where_clause: ChatSessionWhereInput = {"id": session_id}
|
|
||||||
if user_id is not None:
|
|
||||||
where_clause["userId"] = user_id
|
|
||||||
|
|
||||||
result = await PrismaChatSession.prisma().delete_many(where=where_clause)
|
|
||||||
if result == 0:
|
|
||||||
logger.warning(
|
|
||||||
f"No session deleted for {session_id} "
|
|
||||||
f"(user_id validation: {user_id is not None})"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to delete chat session {session_id}: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def get_chat_session_message_count(session_id: str) -> int:
|
|
||||||
"""Get the number of messages in a chat session."""
|
|
||||||
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
|
||||||
return count
|
|
||||||
|
|
||||||
|
|
||||||
async def update_tool_message_content(
|
|
||||||
session_id: str,
|
|
||||||
tool_call_id: str,
|
|
||||||
new_content: str,
|
|
||||||
) -> bool:
|
|
||||||
"""Update the content of a tool message in chat history.
|
|
||||||
|
|
||||||
Used by background tasks to update pending operation messages with final results.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session_id: The chat session ID.
|
|
||||||
tool_call_id: The tool call ID to find the message.
|
|
||||||
new_content: The new content to set.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if a message was updated, False otherwise.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await PrismaChatMessage.prisma().update_many(
|
|
||||||
where={
|
|
||||||
"sessionId": session_id,
|
|
||||||
"toolCallId": tool_call_id,
|
|
||||||
},
|
|
||||||
data={
|
|
||||||
"content": new_content,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if result == 0:
|
|
||||||
logger.warning(
|
|
||||||
f"No message found to update for session {session_id}, "
|
|
||||||
f"tool_call_id {tool_call_id}"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to update tool message for session {session_id}, "
|
|
||||||
f"tool_call_id {tool_call_id}: {e}"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
@@ -1,29 +1,41 @@
|
|||||||
"""Chat API routes for chat session management and streaming via SSE."""
|
"""Chat API routes for chat session management and streaming via SSE."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import uuid as uuid_module
|
import re
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
from autogpt_libs import auth
|
from autogpt_libs import auth
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
|
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from prisma.models import UserWorkspaceFile
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.copilot import service as chat_service
|
||||||
|
from backend.copilot import stream_registry
|
||||||
from . import service as chat_service
|
from backend.copilot.config import ChatConfig
|
||||||
from . import stream_registry
|
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
|
||||||
from .completion_handler import process_operation_failure, process_operation_success
|
from backend.copilot.model import (
|
||||||
from .config import ChatConfig
|
ChatMessage,
|
||||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
ChatSession,
|
||||||
from .response_model import StreamFinish, StreamHeartbeat
|
append_and_save_message,
|
||||||
from .tools.models import (
|
create_chat_session,
|
||||||
|
delete_chat_session,
|
||||||
|
get_chat_session,
|
||||||
|
get_user_sessions,
|
||||||
|
update_session_title,
|
||||||
|
)
|
||||||
|
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||||
|
from backend.copilot.tools.e2b_sandbox import kill_sandbox
|
||||||
|
from backend.copilot.tools.models import (
|
||||||
AgentDetailsResponse,
|
AgentDetailsResponse,
|
||||||
AgentOutputResponse,
|
AgentOutputResponse,
|
||||||
AgentPreviewResponse,
|
AgentPreviewResponse,
|
||||||
AgentSavedResponse,
|
AgentSavedResponse,
|
||||||
AgentsFoundResponse,
|
AgentsFoundResponse,
|
||||||
|
BlockDetailsResponse,
|
||||||
BlockListResponse,
|
BlockListResponse,
|
||||||
BlockOutputResponse,
|
BlockOutputResponse,
|
||||||
ClarificationNeededResponse,
|
ClarificationNeededResponse,
|
||||||
@@ -32,17 +44,25 @@ from .tools.models import (
|
|||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
ExecutionStartedResponse,
|
ExecutionStartedResponse,
|
||||||
InputValidationErrorResponse,
|
InputValidationErrorResponse,
|
||||||
|
MCPToolOutputResponse,
|
||||||
|
MCPToolsDiscoveredResponse,
|
||||||
NeedLoginResponse,
|
NeedLoginResponse,
|
||||||
NoResultsResponse,
|
NoResultsResponse,
|
||||||
OperationInProgressResponse,
|
|
||||||
OperationPendingResponse,
|
|
||||||
OperationStartedResponse,
|
|
||||||
SetupRequirementsResponse,
|
SetupRequirementsResponse,
|
||||||
|
SuggestedGoalResponse,
|
||||||
UnderstandingUpdatedResponse,
|
UnderstandingUpdatedResponse,
|
||||||
)
|
)
|
||||||
|
from backend.copilot.tracking import track_user_message
|
||||||
|
from backend.data.redis_client import get_redis_async
|
||||||
|
from backend.data.understanding import get_business_understanding
|
||||||
|
from backend.data.workspace import get_or_create_workspace
|
||||||
|
from backend.util.exceptions import NotFoundError
|
||||||
|
|
||||||
config = ChatConfig()
|
config = ChatConfig()
|
||||||
|
|
||||||
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -71,6 +91,9 @@ class StreamChatRequest(BaseModel):
|
|||||||
message: str
|
message: str
|
||||||
is_user_message: bool = True
|
is_user_message: bool = True
|
||||||
context: dict[str, str] | None = None # {url: str, content: str}
|
context: dict[str, str] | None = None # {url: str, content: str}
|
||||||
|
file_ids: list[str] | None = Field(
|
||||||
|
default=None, max_length=20
|
||||||
|
) # Workspace file IDs attached to this message
|
||||||
|
|
||||||
|
|
||||||
class CreateSessionResponse(BaseModel):
|
class CreateSessionResponse(BaseModel):
|
||||||
@@ -84,10 +107,8 @@ class CreateSessionResponse(BaseModel):
|
|||||||
class ActiveStreamInfo(BaseModel):
|
class ActiveStreamInfo(BaseModel):
|
||||||
"""Information about an active stream for reconnection."""
|
"""Information about an active stream for reconnection."""
|
||||||
|
|
||||||
task_id: str
|
turn_id: str
|
||||||
last_message_id: str # Redis Stream message ID for resumption
|
last_message_id: str # Redis Stream message ID for resumption
|
||||||
operation_id: str # Operation ID for completion tracking
|
|
||||||
tool_name: str # Name of the tool being executed
|
|
||||||
|
|
||||||
|
|
||||||
class SessionDetailResponse(BaseModel):
|
class SessionDetailResponse(BaseModel):
|
||||||
@@ -108,6 +129,7 @@ class SessionSummaryResponse(BaseModel):
|
|||||||
created_at: str
|
created_at: str
|
||||||
updated_at: str
|
updated_at: str
|
||||||
title: str | None = None
|
title: str | None = None
|
||||||
|
is_processing: bool
|
||||||
|
|
||||||
|
|
||||||
class ListSessionsResponse(BaseModel):
|
class ListSessionsResponse(BaseModel):
|
||||||
@@ -117,12 +139,25 @@ class ListSessionsResponse(BaseModel):
|
|||||||
total: int
|
total: int
|
||||||
|
|
||||||
|
|
||||||
class OperationCompleteRequest(BaseModel):
|
class CancelSessionResponse(BaseModel):
|
||||||
"""Request model for external completion webhook."""
|
"""Response model for the cancel session endpoint."""
|
||||||
|
|
||||||
success: bool
|
cancelled: bool
|
||||||
result: dict | str | None = None
|
reason: str | None = None
|
||||||
error: str | None = None
|
|
||||||
|
|
||||||
|
class UpdateSessionTitleRequest(BaseModel):
|
||||||
|
"""Request model for updating a session's title."""
|
||||||
|
|
||||||
|
title: str
|
||||||
|
|
||||||
|
@field_validator("title")
|
||||||
|
@classmethod
|
||||||
|
def title_must_not_be_blank(cls, v: str) -> str:
|
||||||
|
stripped = v.strip()
|
||||||
|
if not stripped:
|
||||||
|
raise ValueError("Title must not be blank")
|
||||||
|
return stripped
|
||||||
|
|
||||||
|
|
||||||
# ========== Routes ==========
|
# ========== Routes ==========
|
||||||
@@ -153,6 +188,28 @@ async def list_sessions(
|
|||||||
"""
|
"""
|
||||||
sessions, total_count = await get_user_sessions(user_id, limit, offset)
|
sessions, total_count = await get_user_sessions(user_id, limit, offset)
|
||||||
|
|
||||||
|
# Batch-check Redis for active stream status on each session
|
||||||
|
processing_set: set[str] = set()
|
||||||
|
if sessions:
|
||||||
|
try:
|
||||||
|
redis = await get_redis_async()
|
||||||
|
pipe = redis.pipeline(transaction=False)
|
||||||
|
for session in sessions:
|
||||||
|
pipe.hget(
|
||||||
|
f"{config.session_meta_prefix}{session.session_id}",
|
||||||
|
"status",
|
||||||
|
)
|
||||||
|
statuses = await pipe.execute()
|
||||||
|
processing_set = {
|
||||||
|
session.session_id
|
||||||
|
for session, st in zip(sessions, statuses)
|
||||||
|
if st == "running"
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to fetch processing status from Redis; " "defaulting to empty"
|
||||||
|
)
|
||||||
|
|
||||||
return ListSessionsResponse(
|
return ListSessionsResponse(
|
||||||
sessions=[
|
sessions=[
|
||||||
SessionSummaryResponse(
|
SessionSummaryResponse(
|
||||||
@@ -160,6 +217,7 @@ async def list_sessions(
|
|||||||
created_at=session.started_at.isoformat(),
|
created_at=session.started_at.isoformat(),
|
||||||
updated_at=session.updated_at.isoformat(),
|
updated_at=session.updated_at.isoformat(),
|
||||||
title=session.title,
|
title=session.title,
|
||||||
|
is_processing=session.session_id in processing_set,
|
||||||
)
|
)
|
||||||
for session in sessions
|
for session in sessions
|
||||||
],
|
],
|
||||||
@@ -199,6 +257,92 @@ async def create_session(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete(
|
||||||
|
"/sessions/{session_id}",
|
||||||
|
dependencies=[Security(auth.requires_user)],
|
||||||
|
status_code=204,
|
||||||
|
responses={404: {"description": "Session not found or access denied"}},
|
||||||
|
)
|
||||||
|
async def delete_session(
|
||||||
|
session_id: str,
|
||||||
|
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||||
|
) -> Response:
|
||||||
|
"""
|
||||||
|
Delete a chat session.
|
||||||
|
|
||||||
|
Permanently removes a chat session and all its messages.
|
||||||
|
Only the owner can delete their sessions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The session ID to delete.
|
||||||
|
user_id: The authenticated user's ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
204 No Content on success.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 404 if session not found or not owned by user.
|
||||||
|
"""
|
||||||
|
deleted = await delete_chat_session(session_id, user_id)
|
||||||
|
|
||||||
|
if not deleted:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"Session {session_id} not found or access denied",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Best-effort cleanup of the E2B sandbox (if any).
|
||||||
|
# sandbox_id is in Redis; kill_sandbox() fetches it from there.
|
||||||
|
e2b_cfg = ChatConfig()
|
||||||
|
if e2b_cfg.e2b_active:
|
||||||
|
assert e2b_cfg.e2b_api_key # guaranteed by e2b_active check
|
||||||
|
try:
|
||||||
|
await kill_sandbox(session_id, e2b_cfg.e2b_api_key)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"[E2B] Failed to kill sandbox for session %s", session_id[:12]
|
||||||
|
)
|
||||||
|
|
||||||
|
return Response(status_code=204)
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch(
|
||||||
|
"/sessions/{session_id}/title",
|
||||||
|
summary="Update session title",
|
||||||
|
dependencies=[Security(auth.requires_user)],
|
||||||
|
status_code=200,
|
||||||
|
responses={404: {"description": "Session not found or access denied"}},
|
||||||
|
)
|
||||||
|
async def update_session_title_route(
|
||||||
|
session_id: str,
|
||||||
|
request: UpdateSessionTitleRequest,
|
||||||
|
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Update the title of a chat session.
|
||||||
|
|
||||||
|
Allows the user to rename their chat session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The session ID to update.
|
||||||
|
request: Request body containing the new title.
|
||||||
|
user_id: The authenticated user's ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Status of the update.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 404 if session not found or not owned by user.
|
||||||
|
"""
|
||||||
|
success = await update_session_title(session_id, user_id, request.title)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"Session {session_id} not found or access denied",
|
||||||
|
)
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/sessions/{session_id}",
|
"/sessions/{session_id}",
|
||||||
)
|
)
|
||||||
@@ -210,7 +354,7 @@ async def get_session(
|
|||||||
Retrieve the details of a specific chat session.
|
Retrieve the details of a specific chat session.
|
||||||
|
|
||||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||||
If there's an active stream for this session, returns the task_id for reconnection.
|
If there's an active stream for this session, returns active_stream info for reconnection.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_id: The unique identifier for the desired chat session.
|
session_id: The unique identifier for the desired chat session.
|
||||||
@@ -228,24 +372,21 @@ async def get_session(
|
|||||||
|
|
||||||
# Check if there's an active stream for this session
|
# Check if there's an active stream for this session
|
||||||
active_stream_info = None
|
active_stream_info = None
|
||||||
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
active_session, last_message_id = await stream_registry.get_active_session(
|
||||||
session_id, user_id
|
session_id, user_id
|
||||||
)
|
)
|
||||||
if active_task:
|
logger.info(
|
||||||
# Filter out the in-progress assistant message from the session response.
|
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
|
||||||
# The client will receive the complete assistant response through the SSE
|
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
|
||||||
# stream replay instead, preventing duplicate content.
|
)
|
||||||
if messages and messages[-1].get("role") == "assistant":
|
if active_session:
|
||||||
messages = messages[:-1]
|
# Keep the assistant message (including tool_calls) so the frontend can
|
||||||
|
# render the correct tool UI (e.g. CreateAgent with mini game).
|
||||||
# Use "0-0" as last_message_id to replay the stream from the beginning.
|
# convertChatSessionToUiMessages handles isComplete=false by setting
|
||||||
# Since we filtered out the cached assistant message, the client needs
|
# tool parts without output to state "input-available".
|
||||||
# the full stream to reconstruct the response.
|
|
||||||
active_stream_info = ActiveStreamInfo(
|
active_stream_info = ActiveStreamInfo(
|
||||||
task_id=active_task.task_id,
|
turn_id=active_session.turn_id,
|
||||||
last_message_id="0-0",
|
last_message_id=last_message_id,
|
||||||
operation_id=active_task.operation_id,
|
|
||||||
tool_name=active_task.tool_name,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return SessionDetailResponse(
|
return SessionDetailResponse(
|
||||||
@@ -258,6 +399,51 @@ async def get_session(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/sessions/{session_id}/cancel",
|
||||||
|
status_code=200,
|
||||||
|
)
|
||||||
|
async def cancel_session_task(
|
||||||
|
session_id: str,
|
||||||
|
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||||
|
) -> CancelSessionResponse:
|
||||||
|
"""Cancel the active streaming task for a session.
|
||||||
|
|
||||||
|
Publishes a cancel event to the executor via RabbitMQ FANOUT, then
|
||||||
|
polls Redis until the task status flips from ``running`` or a timeout
|
||||||
|
(5 s) is reached. Returns only after the cancellation is confirmed.
|
||||||
|
"""
|
||||||
|
await _validate_and_get_session(session_id, user_id)
|
||||||
|
|
||||||
|
active_session, _ = await stream_registry.get_active_session(session_id, user_id)
|
||||||
|
if not active_session:
|
||||||
|
return CancelSessionResponse(cancelled=True, reason="no_active_session")
|
||||||
|
|
||||||
|
await enqueue_cancel_task(session_id)
|
||||||
|
logger.info(f"[CANCEL] Published cancel for session ...{session_id[-8:]}")
|
||||||
|
|
||||||
|
# Poll until the executor confirms the task is no longer running.
|
||||||
|
poll_interval = 0.5
|
||||||
|
max_wait = 5.0
|
||||||
|
waited = 0.0
|
||||||
|
while waited < max_wait:
|
||||||
|
await asyncio.sleep(poll_interval)
|
||||||
|
waited += poll_interval
|
||||||
|
session_state = await stream_registry.get_session(session_id)
|
||||||
|
if session_state is None or session_state.status != "running":
|
||||||
|
logger.info(
|
||||||
|
f"[CANCEL] Session ...{session_id[-8:]} confirmed stopped "
|
||||||
|
f"(status={session_state.status if session_state else 'gone'}) after {waited:.1f}s"
|
||||||
|
)
|
||||||
|
return CancelSessionResponse(cancelled=True)
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
f"[CANCEL] Session ...{session_id[-8:]} not confirmed after {max_wait}s, force-completing"
|
||||||
|
)
|
||||||
|
await stream_registry.mark_session_completed(session_id, error_message="Cancelled")
|
||||||
|
return CancelSessionResponse(cancelled=True)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/sessions/{session_id}/stream",
|
"/sessions/{session_id}/stream",
|
||||||
)
|
)
|
||||||
@@ -275,16 +461,15 @@ async def stream_chat_post(
|
|||||||
- Tool execution results
|
- Tool execution results
|
||||||
|
|
||||||
The AI generation runs in a background task that continues even if the client disconnects.
|
The AI generation runs in a background task that continues even if the client disconnects.
|
||||||
All chunks are written to Redis for reconnection support. If the client disconnects,
|
All chunks are written to a per-turn Redis stream for reconnection support. If the client
|
||||||
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
|
disconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_id: The chat session identifier to associate with the streamed messages.
|
session_id: The chat session identifier to associate with the streamed messages.
|
||||||
request: Request body containing message, is_user_message, and optional context.
|
request: Request body containing message, is_user_message, and optional context.
|
||||||
user_id: Optional authenticated user ID.
|
user_id: Optional authenticated user ID.
|
||||||
Returns:
|
Returns:
|
||||||
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
|
StreamingResponse: SSE-formatted response chunks.
|
||||||
containing the task_id for reconnection.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -300,10 +485,9 @@ async def stream_chat_post(
|
|||||||
f"user={user_id}, message_len={len(request.message)}",
|
f"user={user_id}, message_len={len(request.message)}",
|
||||||
extra={"json_fields": log_meta},
|
extra={"json_fields": log_meta},
|
||||||
)
|
)
|
||||||
|
await _validate_and_get_session(session_id, user_id)
|
||||||
session = await _validate_and_get_session(session_id, user_id)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
|
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
|
||||||
extra={
|
extra={
|
||||||
"json_fields": {
|
"json_fields": {
|
||||||
**log_meta,
|
**log_meta,
|
||||||
@@ -312,106 +496,95 @@ async def stream_chat_post(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a task in the stream registry for reconnection support
|
# Enrich message with file metadata if file_ids are provided.
|
||||||
task_id = str(uuid_module.uuid4())
|
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||||
operation_id = str(uuid_module.uuid4())
|
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||||
log_meta["task_id"] = task_id
|
sanitized_file_ids: list[str] | None = None
|
||||||
|
if request.file_ids and user_id:
|
||||||
|
# Filter to valid UUIDs only to prevent DB abuse
|
||||||
|
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
|
||||||
|
|
||||||
task_create_start = time.perf_counter()
|
if valid_ids:
|
||||||
await stream_registry.create_task(
|
workspace = await get_or_create_workspace(user_id)
|
||||||
task_id=task_id,
|
# Batch query instead of N+1
|
||||||
|
files = await UserWorkspaceFile.prisma().find_many(
|
||||||
|
where={
|
||||||
|
"id": {"in": valid_ids},
|
||||||
|
"workspaceId": workspace.id,
|
||||||
|
"isDeleted": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Only keep IDs that actually exist in the user's workspace
|
||||||
|
sanitized_file_ids = [wf.id for wf in files] or None
|
||||||
|
file_lines: list[str] = [
|
||||||
|
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
|
||||||
|
for wf in files
|
||||||
|
]
|
||||||
|
if file_lines:
|
||||||
|
files_block = (
|
||||||
|
"\n\n[Attached files]\n"
|
||||||
|
+ "\n".join(file_lines)
|
||||||
|
+ "\nUse read_workspace_file with the file_id to access file contents."
|
||||||
|
)
|
||||||
|
request.message += files_block
|
||||||
|
|
||||||
|
# Atomically append user message to session BEFORE creating task to avoid
|
||||||
|
# race condition where GET_SESSION sees task as "running" but message isn't
|
||||||
|
# saved yet. append_and_save_message re-fetches inside a lock to prevent
|
||||||
|
# message loss from concurrent requests.
|
||||||
|
if request.message:
|
||||||
|
message = ChatMessage(
|
||||||
|
role="user" if request.is_user_message else "assistant",
|
||||||
|
content=request.message,
|
||||||
|
)
|
||||||
|
if request.is_user_message:
|
||||||
|
track_user_message(
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
message_length=len(request.message),
|
||||||
|
)
|
||||||
|
logger.info(f"[STREAM] Saving user message to session {session_id}")
|
||||||
|
await append_and_save_message(session_id, message)
|
||||||
|
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||||
|
|
||||||
|
# Create a task in the stream registry for reconnection support
|
||||||
|
turn_id = str(uuid4())
|
||||||
|
log_meta["turn_id"] = turn_id
|
||||||
|
|
||||||
|
session_create_start = time.perf_counter()
|
||||||
|
await stream_registry.create_session(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
tool_call_id="chat_stream", # Not a tool call, but needed for the model
|
tool_call_id="chat_stream",
|
||||||
tool_name="chat",
|
tool_name="chat",
|
||||||
operation_id=operation_id,
|
turn_id=turn_id,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
|
f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
|
||||||
extra={
|
extra={
|
||||||
"json_fields": {
|
"json_fields": {
|
||||||
**log_meta,
|
**log_meta,
|
||||||
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
|
"duration_ms": (time.perf_counter() - session_create_start) * 1000,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Background task that runs the AI generation independently of SSE connection
|
# Per-turn stream is always fresh (unique turn_id), subscribe from beginning
|
||||||
async def run_ai_generation():
|
subscribe_from_id = "0-0"
|
||||||
import time as time_module
|
|
||||||
|
|
||||||
gen_start_time = time_module.perf_counter()
|
await enqueue_copilot_turn(
|
||||||
logger.info(
|
session_id=session_id,
|
||||||
f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
|
user_id=user_id,
|
||||||
extra={"json_fields": log_meta},
|
message=request.message,
|
||||||
)
|
turn_id=turn_id,
|
||||||
first_chunk_time, ttfc = None, None
|
is_user_message=request.is_user_message,
|
||||||
chunk_count = 0
|
context=request.context,
|
||||||
try:
|
file_ids=sanitized_file_ids,
|
||||||
async for chunk in chat_service.stream_chat_completion(
|
)
|
||||||
session_id,
|
|
||||||
request.message,
|
|
||||||
is_user_message=request.is_user_message,
|
|
||||||
user_id=user_id,
|
|
||||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
|
||||||
context=request.context,
|
|
||||||
_task_id=task_id, # Pass task_id so service emits start with taskId for reconnection
|
|
||||||
):
|
|
||||||
chunk_count += 1
|
|
||||||
if first_chunk_time is None:
|
|
||||||
first_chunk_time = time_module.perf_counter()
|
|
||||||
ttfc = first_chunk_time - gen_start_time
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"chunk_type": type(chunk).__name__,
|
|
||||||
"time_to_first_chunk_ms": ttfc * 1000,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# Write to Redis (subscribers will receive via XREAD)
|
|
||||||
await stream_registry.publish_chunk(task_id, chunk)
|
|
||||||
|
|
||||||
gen_end_time = time_module.perf_counter()
|
|
||||||
total_time = (gen_end_time - gen_start_time) * 1000
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
|
|
||||||
f"task={task_id}, session={session_id}, "
|
|
||||||
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"total_time_ms": total_time,
|
|
||||||
"time_to_first_chunk_ms": (
|
|
||||||
ttfc * 1000 if ttfc is not None else None
|
|
||||||
),
|
|
||||||
"n_chunks": chunk_count,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
await stream_registry.mark_task_completed(task_id, "completed")
|
|
||||||
except Exception as e:
|
|
||||||
elapsed = time_module.perf_counter() - gen_start_time
|
|
||||||
logger.error(
|
|
||||||
f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"elapsed_ms": elapsed * 1000,
|
|
||||||
"error": str(e),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
await stream_registry.mark_task_completed(task_id, "failed")
|
|
||||||
|
|
||||||
# Start the AI generation in a background task
|
|
||||||
bg_task = asyncio.create_task(run_ai_generation())
|
|
||||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
|
||||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] Background task started, setup={setup_time:.1f}ms",
|
f"[TIMING] Task enqueued to RabbitMQ, setup={setup_time:.1f}ms",
|
||||||
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -421,7 +594,7 @@ async def stream_chat_post(
|
|||||||
|
|
||||||
event_gen_start = time_module.perf_counter()
|
event_gen_start = time_module.perf_counter()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
|
f"[TIMING] event_generator STARTED, turn={turn_id}, session={session_id}, "
|
||||||
f"user={user_id}",
|
f"user={user_id}",
|
||||||
extra={"json_fields": log_meta},
|
extra={"json_fields": log_meta},
|
||||||
)
|
)
|
||||||
@@ -429,11 +602,12 @@ async def stream_chat_post(
|
|||||||
first_chunk_yielded = False
|
first_chunk_yielded = False
|
||||||
chunks_yielded = 0
|
chunks_yielded = 0
|
||||||
try:
|
try:
|
||||||
# Subscribe to the task stream (this replays existing messages + live updates)
|
# Subscribe from the position we captured before enqueuing
|
||||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
# This avoids replaying old messages while catching all new ones
|
||||||
task_id=task_id,
|
subscriber_queue = await stream_registry.subscribe_to_session(
|
||||||
|
session_id=session_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
last_message_id="0-0", # Get all messages from the beginning
|
last_message_id=subscribe_from_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if subscriber_queue is None:
|
if subscriber_queue is None:
|
||||||
@@ -448,7 +622,7 @@ async def stream_chat_post(
|
|||||||
)
|
)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
|
||||||
chunks_yielded += 1
|
chunks_yielded += 1
|
||||||
|
|
||||||
if not first_chunk_yielded:
|
if not first_chunk_yielded:
|
||||||
@@ -506,23 +680,29 @@ async def stream_chat_post(
|
|||||||
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
|
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
# Surface error to frontend so it doesn't appear stuck
|
||||||
|
yield StreamError(
|
||||||
|
errorText="An error occurred. Please try again.",
|
||||||
|
code="stream_error",
|
||||||
|
).to_sse()
|
||||||
|
yield StreamFinish().to_sse()
|
||||||
finally:
|
finally:
|
||||||
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
# Unsubscribe when client disconnects or stream ends
|
||||||
if subscriber_queue is not None:
|
if subscriber_queue is not None:
|
||||||
try:
|
try:
|
||||||
await stream_registry.unsubscribe_from_task(
|
await stream_registry.unsubscribe_from_session(
|
||||||
task_id, subscriber_queue
|
session_id, subscriber_queue
|
||||||
)
|
)
|
||||||
except Exception as unsub_err:
|
except Exception as unsub_err:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
f"Error unsubscribing from session {session_id}: {unsub_err}",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||||
total_time = time_module.perf_counter() - event_gen_start
|
total_time = time_module.perf_counter() - event_gen_start
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
|
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
|
||||||
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
|
f"turn={turn_id}, session={session_id}, n_chunks={chunks_yielded}",
|
||||||
extra={
|
extra={
|
||||||
"json_fields": {
|
"json_fields": {
|
||||||
**log_meta,
|
**log_meta,
|
||||||
@@ -569,17 +749,21 @@ async def resume_session_stream(
|
|||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
active_task, _last_id = await stream_registry.get_active_task_for_session(
|
active_session, last_message_id = await stream_registry.get_active_session(
|
||||||
session_id, user_id
|
session_id, user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if not active_task:
|
if not active_session:
|
||||||
return Response(status_code=204)
|
return Response(status_code=204)
|
||||||
|
|
||||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
# Always replay from the beginning ("0-0") on resume.
|
||||||
task_id=active_task.task_id,
|
# We can't use last_message_id because it's the latest ID in the backend
|
||||||
|
# stream, not the latest the frontend received — the gap causes lost
|
||||||
|
# messages. The frontend deduplicates replayed content.
|
||||||
|
subscriber_queue = await stream_registry.subscribe_to_session(
|
||||||
|
session_id=session_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
last_message_id="0-0", # Full replay so useChat rebuilds the message
|
last_message_id="0-0",
|
||||||
)
|
)
|
||||||
|
|
||||||
if subscriber_queue is None:
|
if subscriber_queue is None:
|
||||||
@@ -591,7 +775,7 @@ async def resume_session_stream(
|
|||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
|
||||||
if chunk_count < 3:
|
if chunk_count < 3:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Resume stream chunk",
|
"Resume stream chunk",
|
||||||
@@ -615,12 +799,12 @@ async def resume_session_stream(
|
|||||||
logger.error(f"Error in resume stream for session {session_id}: {e}")
|
logger.error(f"Error in resume stream for session {session_id}: {e}")
|
||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
await stream_registry.unsubscribe_from_task(
|
await stream_registry.unsubscribe_from_session(
|
||||||
active_task.task_id, subscriber_queue
|
session_id, subscriber_queue
|
||||||
)
|
)
|
||||||
except Exception as unsub_err:
|
except Exception as unsub_err:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
|
f"Error unsubscribing from session {active_session.session_id}: {unsub_err}",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -648,7 +832,6 @@ async def resume_session_stream(
|
|||||||
@router.patch(
|
@router.patch(
|
||||||
"/sessions/{session_id}/assign-user",
|
"/sessions/{session_id}/assign-user",
|
||||||
dependencies=[Security(auth.requires_user)],
|
dependencies=[Security(auth.requires_user)],
|
||||||
status_code=200,
|
|
||||||
)
|
)
|
||||||
async def session_assign_user(
|
async def session_assign_user(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
@@ -671,229 +854,34 @@ async def session_assign_user(
|
|||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
# ========== Task Streaming (SSE Reconnection) ==========
|
# ========== Suggested Prompts ==========
|
||||||
|
|
||||||
|
|
||||||
|
class SuggestedPromptsResponse(BaseModel):
|
||||||
|
"""Response model for user-specific suggested prompts."""
|
||||||
|
|
||||||
|
prompts: list[str]
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/tasks/{task_id}/stream",
|
"/suggested-prompts",
|
||||||
|
dependencies=[Security(auth.requires_user)],
|
||||||
)
|
)
|
||||||
async def stream_task(
|
async def get_suggested_prompts(
|
||||||
task_id: str,
|
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||||
user_id: str | None = Depends(auth.get_user_id),
|
) -> SuggestedPromptsResponse:
|
||||||
last_message_id: str = Query(
|
|
||||||
default="0-0",
|
|
||||||
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
|
||||||
),
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Reconnect to a long-running task's SSE stream.
|
Get LLM-generated suggested prompts for the authenticated user.
|
||||||
|
|
||||||
When a long-running operation (like agent generation) starts, the client
|
Returns personalized quick-action prompts based on the user's
|
||||||
receives a task_id. If the connection drops, the client can reconnect
|
business understanding. Returns an empty list if no custom prompts
|
||||||
using this endpoint to resume receiving updates.
|
are available.
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id: The task ID from the operation_started response.
|
|
||||||
user_id: Authenticated user ID for ownership validation.
|
|
||||||
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
|
|
||||||
"""
|
"""
|
||||||
# Check task existence and expiry before subscribing
|
understanding = await get_business_understanding(user_id)
|
||||||
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
|
if understanding is None:
|
||||||
|
return SuggestedPromptsResponse(prompts=[])
|
||||||
|
|
||||||
if error_code == "TASK_EXPIRED":
|
return SuggestedPromptsResponse(prompts=understanding.suggested_prompts)
|
||||||
raise HTTPException(
|
|
||||||
status_code=410,
|
|
||||||
detail={
|
|
||||||
"code": "TASK_EXPIRED",
|
|
||||||
"message": "This operation has expired. Please try again.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if error_code == "TASK_NOT_FOUND":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404,
|
|
||||||
detail={
|
|
||||||
"code": "TASK_NOT_FOUND",
|
|
||||||
"message": f"Task {task_id} not found.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate ownership if task has an owner
|
|
||||||
if task and task.user_id and user_id != task.user_id:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403,
|
|
||||||
detail={
|
|
||||||
"code": "ACCESS_DENIED",
|
|
||||||
"message": "You do not have access to this task.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get subscriber queue from stream registry
|
|
||||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
|
||||||
task_id=task_id,
|
|
||||||
user_id=user_id,
|
|
||||||
last_message_id=last_message_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if subscriber_queue is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404,
|
|
||||||
detail={
|
|
||||||
"code": "TASK_NOT_FOUND",
|
|
||||||
"message": f"Task {task_id} not found or access denied.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
# Wait for next chunk with timeout for heartbeats
|
|
||||||
chunk = await asyncio.wait_for(
|
|
||||||
subscriber_queue.get(), timeout=heartbeat_interval
|
|
||||||
)
|
|
||||||
yield chunk.to_sse()
|
|
||||||
|
|
||||||
# Check for finish signal
|
|
||||||
if isinstance(chunk, StreamFinish):
|
|
||||||
break
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
# Send heartbeat to keep connection alive
|
|
||||||
yield StreamHeartbeat().to_sse()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
|
|
||||||
finally:
|
|
||||||
# Unsubscribe when client disconnects or stream ends
|
|
||||||
try:
|
|
||||||
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
|
|
||||||
except Exception as unsub_err:
|
|
||||||
logger.error(
|
|
||||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
|
||||||
yield "data: [DONE]\n\n"
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
event_generator(),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
"X-Accel-Buffering": "no",
|
|
||||||
"x-vercel-ai-ui-message-stream": "v1",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/tasks/{task_id}",
|
|
||||||
)
|
|
||||||
async def get_task_status(
|
|
||||||
task_id: str,
|
|
||||||
user_id: str | None = Depends(auth.get_user_id),
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Get the status of a long-running task.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id: The task ID to check.
|
|
||||||
user_id: Authenticated user ID for ownership validation.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Task status including task_id, status, tool_name, and operation_id.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
NotFoundError: If task_id is not found or user doesn't have access.
|
|
||||||
"""
|
|
||||||
task = await stream_registry.get_task(task_id)
|
|
||||||
|
|
||||||
if task is None:
|
|
||||||
raise NotFoundError(f"Task {task_id} not found.")
|
|
||||||
|
|
||||||
# Validate ownership - if task has an owner, requester must match
|
|
||||||
if task.user_id and user_id != task.user_id:
|
|
||||||
raise NotFoundError(f"Task {task_id} not found.")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"task_id": task.task_id,
|
|
||||||
"session_id": task.session_id,
|
|
||||||
"status": task.status,
|
|
||||||
"tool_name": task.tool_name,
|
|
||||||
"operation_id": task.operation_id,
|
|
||||||
"created_at": task.created_at.isoformat(),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ========== External Completion Webhook ==========
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/operations/{operation_id}/complete",
|
|
||||||
status_code=200,
|
|
||||||
)
|
|
||||||
async def complete_operation(
|
|
||||||
operation_id: str,
|
|
||||||
request: OperationCompleteRequest,
|
|
||||||
x_api_key: str | None = Header(default=None),
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
External completion webhook for long-running operations.
|
|
||||||
|
|
||||||
Called by Agent Generator (or other services) when an operation completes.
|
|
||||||
This triggers the stream registry to publish completion and continue LLM generation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
operation_id: The operation ID to complete.
|
|
||||||
request: Completion payload with success status and result/error.
|
|
||||||
x_api_key: Internal API key for authentication.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Status of the completion.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If API key is invalid or operation not found.
|
|
||||||
"""
|
|
||||||
# Validate internal API key - reject if not configured or invalid
|
|
||||||
if not config.internal_api_key:
|
|
||||||
logger.error(
|
|
||||||
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=503,
|
|
||||||
detail="Webhook not available: internal API key not configured",
|
|
||||||
)
|
|
||||||
if x_api_key != config.internal_api_key:
|
|
||||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
|
||||||
|
|
||||||
# Find task by operation_id
|
|
||||||
task = await stream_registry.find_task_by_operation_id(operation_id)
|
|
||||||
if task is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404,
|
|
||||||
detail=f"Operation {operation_id} not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Received completion webhook for operation {operation_id} "
|
|
||||||
f"(task_id={task.task_id}, success={request.success})"
|
|
||||||
)
|
|
||||||
|
|
||||||
if request.success:
|
|
||||||
await process_operation_success(task, request.result)
|
|
||||||
else:
|
|
||||||
await process_operation_failure(task, request.error)
|
|
||||||
|
|
||||||
return {"status": "ok", "task_id": task.task_id}
|
|
||||||
|
|
||||||
|
|
||||||
# ========== Configuration ==========
|
# ========== Configuration ==========
|
||||||
@@ -970,13 +958,14 @@ ToolResponseUnion = (
|
|||||||
| AgentPreviewResponse
|
| AgentPreviewResponse
|
||||||
| AgentSavedResponse
|
| AgentSavedResponse
|
||||||
| ClarificationNeededResponse
|
| ClarificationNeededResponse
|
||||||
|
| SuggestedGoalResponse
|
||||||
| BlockListResponse
|
| BlockListResponse
|
||||||
|
| BlockDetailsResponse
|
||||||
| BlockOutputResponse
|
| BlockOutputResponse
|
||||||
| DocSearchResultsResponse
|
| DocSearchResultsResponse
|
||||||
| DocPageResponse
|
| DocPageResponse
|
||||||
| OperationStartedResponse
|
| MCPToolsDiscoveredResponse
|
||||||
| OperationPendingResponse
|
| MCPToolOutputResponse
|
||||||
| OperationInProgressResponse
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,310 @@
|
|||||||
|
"""Tests for chat API routes: session title update, file attachment validation, and suggested prompts."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import fastapi
|
||||||
|
import fastapi.testclient
|
||||||
|
import pytest
|
||||||
|
import pytest_mock
|
||||||
|
|
||||||
|
from backend.api.features.chat import routes as chat_routes
|
||||||
|
|
||||||
|
app = fastapi.FastAPI()
|
||||||
|
app.include_router(chat_routes.router)
|
||||||
|
|
||||||
|
client = fastapi.testclient.TestClient(app)
|
||||||
|
|
||||||
|
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_app_auth(mock_jwt_user):
|
||||||
|
"""Setup auth overrides for all tests in this module"""
|
||||||
|
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||||
|
|
||||||
|
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_update_session_title(
|
||||||
|
mocker: pytest_mock.MockerFixture, *, success: bool = True
|
||||||
|
):
|
||||||
|
"""Mock update_session_title."""
|
||||||
|
return mocker.patch(
|
||||||
|
"backend.api.features.chat.routes.update_session_title",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=success,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Update title: success ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_title_success(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
test_user_id: str,
|
||||||
|
) -> None:
|
||||||
|
mock_update = _mock_update_session_title(mocker, success=True)
|
||||||
|
|
||||||
|
response = client.patch(
|
||||||
|
"/sessions/sess-1/title",
|
||||||
|
json={"title": "My project"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"status": "ok"}
|
||||||
|
mock_update.assert_called_once_with("sess-1", test_user_id, "My project")
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_title_trims_whitespace(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
test_user_id: str,
|
||||||
|
) -> None:
|
||||||
|
mock_update = _mock_update_session_title(mocker, success=True)
|
||||||
|
|
||||||
|
response = client.patch(
|
||||||
|
"/sessions/sess-1/title",
|
||||||
|
json={"title": " trimmed "},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
mock_update.assert_called_once_with("sess-1", test_user_id, "trimmed")
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Update title: blank / whitespace-only → 422 ──────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_title_blank_rejected(
|
||||||
|
test_user_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Whitespace-only titles must be rejected before hitting the DB."""
|
||||||
|
response = client.patch(
|
||||||
|
"/sessions/sess-1/title",
|
||||||
|
json={"title": " "},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_title_empty_rejected(
|
||||||
|
test_user_id: str,
|
||||||
|
) -> None:
|
||||||
|
response = client.patch(
|
||||||
|
"/sessions/sess-1/title",
|
||||||
|
json={"title": ""},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Update title: session not found or wrong user → 404 ──────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_title_not_found(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
test_user_id: str,
|
||||||
|
) -> None:
|
||||||
|
_mock_update_session_title(mocker, success=False)
|
||||||
|
|
||||||
|
response = client.patch(
|
||||||
|
"/sessions/sess-1/title",
|
||||||
|
json={"title": "New name"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
# ─── file_ids Pydantic validation ─────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_chat_rejects_too_many_file_ids():
|
||||||
|
"""More than 20 file_ids should be rejected by Pydantic validation (422)."""
|
||||||
|
response = client.post(
|
||||||
|
"/sessions/sess-1/stream",
|
||||||
|
json={
|
||||||
|
"message": "hello",
|
||||||
|
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(21)],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_stream_internals(mocker: pytest_mock.MockFixture):
|
||||||
|
"""Mock the async internals of stream_chat_post so tests can exercise
|
||||||
|
validation and enrichment logic without needing Redis/RabbitMQ."""
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.chat.routes._validate_and_get_session",
|
||||||
|
return_value=None,
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.chat.routes.append_and_save_message",
|
||||||
|
return_value=None,
|
||||||
|
)
|
||||||
|
mock_registry = mocker.MagicMock()
|
||||||
|
mock_registry.create_session = mocker.AsyncMock(return_value=None)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.chat.routes.stream_registry",
|
||||||
|
mock_registry,
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.chat.routes.enqueue_copilot_turn",
|
||||||
|
return_value=None,
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.chat.routes.track_user_message",
|
||||||
|
return_value=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
|
||||||
|
"""Exactly 20 file_ids should be accepted (not rejected by validation)."""
|
||||||
|
_mock_stream_internals(mocker)
|
||||||
|
# Patch workspace lookup as imported by the routes module
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||||
|
return_value=type("W", (), {"id": "ws-1"})(),
|
||||||
|
)
|
||||||
|
mock_prisma = mocker.MagicMock()
|
||||||
|
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||||
|
mocker.patch(
|
||||||
|
"prisma.models.UserWorkspaceFile.prisma",
|
||||||
|
return_value=mock_prisma,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/sessions/sess-1/stream",
|
||||||
|
json={
|
||||||
|
"message": "hello",
|
||||||
|
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(20)],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Should get past validation — 200 streaming response expected
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
# ─── UUID format filtering ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||||
|
"""Non-UUID strings in file_ids should be silently filtered out
|
||||||
|
and NOT passed to the database query."""
|
||||||
|
_mock_stream_internals(mocker)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||||
|
return_value=type("W", (), {"id": "ws-1"})(),
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_prisma = mocker.MagicMock()
|
||||||
|
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||||
|
mocker.patch(
|
||||||
|
"prisma.models.UserWorkspaceFile.prisma",
|
||||||
|
return_value=mock_prisma,
|
||||||
|
)
|
||||||
|
|
||||||
|
valid_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||||
|
client.post(
|
||||||
|
"/sessions/sess-1/stream",
|
||||||
|
json={
|
||||||
|
"message": "hello",
|
||||||
|
"file_ids": [
|
||||||
|
valid_id,
|
||||||
|
"not-a-uuid",
|
||||||
|
"../../../etc/passwd",
|
||||||
|
"",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# The find_many call should only receive the one valid UUID
|
||||||
|
mock_prisma.find_many.assert_called_once()
|
||||||
|
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||||
|
assert call_kwargs["where"]["id"]["in"] == [valid_id]
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Cross-workspace file_ids ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||||
|
"""The batch query should scope to the user's workspace."""
|
||||||
|
_mock_stream_internals(mocker)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.chat.routes.get_or_create_workspace",
|
||||||
|
return_value=type("W", (), {"id": "my-workspace-id"})(),
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_prisma = mocker.MagicMock()
|
||||||
|
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
|
||||||
|
mocker.patch(
|
||||||
|
"prisma.models.UserWorkspaceFile.prisma",
|
||||||
|
return_value=mock_prisma,
|
||||||
|
)
|
||||||
|
|
||||||
|
fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||||
|
client.post(
|
||||||
|
"/sessions/sess-1/stream",
|
||||||
|
json={"message": "hi", "file_ids": [fid]},
|
||||||
|
)
|
||||||
|
|
||||||
|
call_kwargs = mock_prisma.find_many.call_args[1]
|
||||||
|
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
|
||||||
|
assert call_kwargs["where"]["isDeleted"] is False
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_get_business_understanding(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
*,
|
||||||
|
return_value=None,
|
||||||
|
):
|
||||||
|
"""Mock get_business_understanding."""
|
||||||
|
return mocker.patch(
|
||||||
|
"backend.api.features.chat.routes.get_business_understanding",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=return_value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_suggested_prompts_returns_prompts(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
test_user_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""User with understanding and prompts gets them back."""
|
||||||
|
mock_understanding = MagicMock()
|
||||||
|
mock_understanding.suggested_prompts = ["Do X", "Do Y", "Do Z"]
|
||||||
|
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||||
|
|
||||||
|
response = client.get("/suggested-prompts")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"prompts": ["Do X", "Do Y", "Do Z"]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_suggested_prompts_no_understanding(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
test_user_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""User with no understanding gets empty list."""
|
||||||
|
_mock_get_business_understanding(mocker, return_value=None)
|
||||||
|
|
||||||
|
response = client.get("/suggested-prompts")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"prompts": []}
|
||||||
|
|
||||||
|
|
||||||
|
def test_suggested_prompts_empty_prompts(
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
test_user_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""User with understanding but no prompts gets empty list."""
|
||||||
|
mock_understanding = MagicMock()
|
||||||
|
mock_understanding.suggested_prompts = []
|
||||||
|
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
||||||
|
|
||||||
|
response = client.get("/suggested-prompts")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"prompts": []}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,82 +0,0 @@
|
|||||||
import logging
|
|
||||||
from os import getenv
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from . import service as chat_service
|
|
||||||
from .model import create_chat_session, get_chat_session, upsert_chat_session
|
|
||||||
from .response_model import (
|
|
||||||
StreamError,
|
|
||||||
StreamFinish,
|
|
||||||
StreamTextDelta,
|
|
||||||
StreamToolOutputAvailable,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_stream_chat_completion(setup_test_user, test_user_id):
|
|
||||||
"""
|
|
||||||
Test the stream_chat_completion function.
|
|
||||||
"""
|
|
||||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
|
||||||
if not api_key:
|
|
||||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
|
||||||
|
|
||||||
session = await create_chat_session(test_user_id)
|
|
||||||
|
|
||||||
has_errors = False
|
|
||||||
has_ended = False
|
|
||||||
assistant_message = ""
|
|
||||||
async for chunk in chat_service.stream_chat_completion(
|
|
||||||
session.session_id, "Hello, how are you?", user_id=session.user_id
|
|
||||||
):
|
|
||||||
logger.info(chunk)
|
|
||||||
if isinstance(chunk, StreamError):
|
|
||||||
has_errors = True
|
|
||||||
if isinstance(chunk, StreamTextDelta):
|
|
||||||
assistant_message += chunk.delta
|
|
||||||
if isinstance(chunk, StreamFinish):
|
|
||||||
has_ended = True
|
|
||||||
|
|
||||||
assert has_ended, "Chat completion did not end"
|
|
||||||
assert not has_errors, "Error occurred while streaming chat completion"
|
|
||||||
assert assistant_message, "Assistant message is empty"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
|
|
||||||
"""
|
|
||||||
Test the stream_chat_completion function.
|
|
||||||
"""
|
|
||||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
|
||||||
if not api_key:
|
|
||||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
|
||||||
|
|
||||||
session = await create_chat_session(test_user_id)
|
|
||||||
session = await upsert_chat_session(session)
|
|
||||||
|
|
||||||
has_errors = False
|
|
||||||
has_ended = False
|
|
||||||
had_tool_calls = False
|
|
||||||
async for chunk in chat_service.stream_chat_completion(
|
|
||||||
session.session_id,
|
|
||||||
"Please find me an agent that can help me with my business. Use the query 'moneny printing agent'",
|
|
||||||
user_id=session.user_id,
|
|
||||||
):
|
|
||||||
logger.info(chunk)
|
|
||||||
if isinstance(chunk, StreamError):
|
|
||||||
has_errors = True
|
|
||||||
|
|
||||||
if isinstance(chunk, StreamFinish):
|
|
||||||
has_ended = True
|
|
||||||
if isinstance(chunk, StreamToolOutputAvailable):
|
|
||||||
had_tool_calls = True
|
|
||||||
|
|
||||||
assert has_ended, "Chat completion did not end"
|
|
||||||
assert not has_errors, "Error occurred while streaming chat completion"
|
|
||||||
assert had_tool_calls, "Tool calls did not occur"
|
|
||||||
session = await get_chat_session(session.session_id)
|
|
||||||
assert session, "Session not found"
|
|
||||||
assert session.usage, "Usage is empty"
|
|
||||||
@@ -1,498 +0,0 @@
|
|||||||
"""External Agent Generator service client.
|
|
||||||
|
|
||||||
This module provides a client for communicating with the external Agent Generator
|
|
||||||
microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
|
|
||||||
will delegate to the external service instead of using the built-in LLM-based implementation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from backend.util.settings import Settings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def _create_error_response(
|
|
||||||
error_message: str,
|
|
||||||
error_type: str = "unknown",
|
|
||||||
details: dict[str, Any] | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Create a standardized error response dict.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
error_message: Human-readable error message
|
|
||||||
error_type: Machine-readable error type
|
|
||||||
details: Optional additional error details
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Error dict with type="error" and error details
|
|
||||||
"""
|
|
||||||
response: dict[str, Any] = {
|
|
||||||
"type": "error",
|
|
||||||
"error": error_message,
|
|
||||||
"error_type": error_type,
|
|
||||||
}
|
|
||||||
if details:
|
|
||||||
response["details"] = details
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
|
|
||||||
"""Classify an HTTP error into error_type and message.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
e: The HTTP status error
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (error_type, error_message)
|
|
||||||
"""
|
|
||||||
status = e.response.status_code
|
|
||||||
if status == 429:
|
|
||||||
return "rate_limit", f"Agent Generator rate limited: {e}"
|
|
||||||
elif status == 503:
|
|
||||||
return "service_unavailable", f"Agent Generator unavailable: {e}"
|
|
||||||
elif status == 504 or status == 408:
|
|
||||||
return "timeout", f"Agent Generator timed out: {e}"
|
|
||||||
else:
|
|
||||||
return "http_error", f"HTTP error calling Agent Generator: {e}"
|
|
||||||
|
|
||||||
|
|
||||||
def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
|
|
||||||
"""Classify a request error into error_type and message.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
e: The request error
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (error_type, error_message)
|
|
||||||
"""
|
|
||||||
error_str = str(e).lower()
|
|
||||||
if "timeout" in error_str or "timed out" in error_str:
|
|
||||||
return "timeout", f"Agent Generator request timed out: {e}"
|
|
||||||
elif "connect" in error_str:
|
|
||||||
return "connection_error", f"Could not connect to Agent Generator: {e}"
|
|
||||||
else:
|
|
||||||
return "request_error", f"Request error calling Agent Generator: {e}"
|
|
||||||
|
|
||||||
|
|
||||||
_client: httpx.AsyncClient | None = None
|
|
||||||
_settings: Settings | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_settings() -> Settings:
|
|
||||||
"""Get or create settings singleton."""
|
|
||||||
global _settings
|
|
||||||
if _settings is None:
|
|
||||||
_settings = Settings()
|
|
||||||
return _settings
|
|
||||||
|
|
||||||
|
|
||||||
def is_external_service_configured() -> bool:
|
|
||||||
"""Check if external Agent Generator service is configured."""
|
|
||||||
settings = _get_settings()
|
|
||||||
return bool(settings.config.agentgenerator_host)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_base_url() -> str:
|
|
||||||
"""Get the base URL for the external service."""
|
|
||||||
settings = _get_settings()
|
|
||||||
host = settings.config.agentgenerator_host
|
|
||||||
port = settings.config.agentgenerator_port
|
|
||||||
return f"http://{host}:{port}"
|
|
||||||
|
|
||||||
|
|
||||||
def _get_client() -> httpx.AsyncClient:
|
|
||||||
"""Get or create the HTTP client for the external service."""
|
|
||||||
global _client
|
|
||||||
if _client is None:
|
|
||||||
settings = _get_settings()
|
|
||||||
_client = httpx.AsyncClient(
|
|
||||||
base_url=_get_base_url(),
|
|
||||||
timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
|
|
||||||
)
|
|
||||||
return _client
|
|
||||||
|
|
||||||
|
|
||||||
async def decompose_goal_external(
|
|
||||||
description: str,
|
|
||||||
context: str = "",
|
|
||||||
library_agents: list[dict[str, Any]] | None = None,
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Call the external service to decompose a goal.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
description: Natural language goal description
|
|
||||||
context: Additional context (e.g., answers to previous questions)
|
|
||||||
library_agents: User's library agents available for sub-agent composition
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with either:
|
|
||||||
- {"type": "clarifying_questions", "questions": [...]}
|
|
||||||
- {"type": "instructions", "steps": [...]}
|
|
||||||
- {"type": "unachievable_goal", ...}
|
|
||||||
- {"type": "vague_goal", ...}
|
|
||||||
- {"type": "error", "error": "...", "error_type": "..."} on error
|
|
||||||
Or None on unexpected error
|
|
||||||
"""
|
|
||||||
client = _get_client()
|
|
||||||
|
|
||||||
if context:
|
|
||||||
description = f"{description}\n\nAdditional context from user:\n{context}"
|
|
||||||
|
|
||||||
payload: dict[str, Any] = {"description": description}
|
|
||||||
if library_agents:
|
|
||||||
payload["library_agents"] = library_agents
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await client.post("/api/decompose-description", json=payload)
|
|
||||||
response.raise_for_status()
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
if not data.get("success"):
|
|
||||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
|
||||||
error_type = data.get("error_type", "unknown")
|
|
||||||
logger.error(
|
|
||||||
f"Agent Generator decomposition failed: {error_msg} "
|
|
||||||
f"(type: {error_type})"
|
|
||||||
)
|
|
||||||
return _create_error_response(error_msg, error_type)
|
|
||||||
|
|
||||||
# Map the response to the expected format
|
|
||||||
response_type = data.get("type")
|
|
||||||
if response_type == "instructions":
|
|
||||||
return {"type": "instructions", "steps": data.get("steps", [])}
|
|
||||||
elif response_type == "clarifying_questions":
|
|
||||||
return {
|
|
||||||
"type": "clarifying_questions",
|
|
||||||
"questions": data.get("questions", []),
|
|
||||||
}
|
|
||||||
elif response_type == "unachievable_goal":
|
|
||||||
return {
|
|
||||||
"type": "unachievable_goal",
|
|
||||||
"reason": data.get("reason"),
|
|
||||||
"suggested_goal": data.get("suggested_goal"),
|
|
||||||
}
|
|
||||||
elif response_type == "vague_goal":
|
|
||||||
return {
|
|
||||||
"type": "vague_goal",
|
|
||||||
"suggested_goal": data.get("suggested_goal"),
|
|
||||||
}
|
|
||||||
elif response_type == "error":
|
|
||||||
# Pass through error from the service
|
|
||||||
return _create_error_response(
|
|
||||||
data.get("error", "Unknown error"),
|
|
||||||
data.get("error_type", "unknown"),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.error(
|
|
||||||
f"Unknown response type from external service: {response_type}"
|
|
||||||
)
|
|
||||||
return _create_error_response(
|
|
||||||
f"Unknown response type from Agent Generator: {response_type}",
|
|
||||||
"invalid_response",
|
|
||||||
)
|
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
error_type, error_msg = _classify_http_error(e)
|
|
||||||
logger.error(error_msg)
|
|
||||||
return _create_error_response(error_msg, error_type)
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
error_type, error_msg = _classify_request_error(e)
|
|
||||||
logger.error(error_msg)
|
|
||||||
return _create_error_response(error_msg, error_type)
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
return _create_error_response(error_msg, "unexpected_error")
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_external(
|
|
||||||
instructions: dict[str, Any],
|
|
||||||
library_agents: list[dict[str, Any]] | None = None,
|
|
||||||
operation_id: str | None = None,
|
|
||||||
task_id: str | None = None,
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Call the external service to generate an agent from instructions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
instructions: Structured instructions from decompose_goal
|
|
||||||
library_agents: User's library agents available for sub-agent composition
|
|
||||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
|
||||||
task_id: Task ID for async processing (enables Redis Streams callback)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
|
|
||||||
"""
|
|
||||||
client = _get_client()
|
|
||||||
|
|
||||||
# Build request payload
|
|
||||||
payload: dict[str, Any] = {"instructions": instructions}
|
|
||||||
if library_agents:
|
|
||||||
payload["library_agents"] = library_agents
|
|
||||||
if operation_id and task_id:
|
|
||||||
payload["operation_id"] = operation_id
|
|
||||||
payload["task_id"] = task_id
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await client.post("/api/generate-agent", json=payload)
|
|
||||||
|
|
||||||
# Handle 202 Accepted for async processing
|
|
||||||
if response.status_code == 202:
|
|
||||||
logger.info(
|
|
||||||
f"Agent Generator accepted async request "
|
|
||||||
f"(operation_id={operation_id}, task_id={task_id})"
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"status": "accepted",
|
|
||||||
"operation_id": operation_id,
|
|
||||||
"task_id": task_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
response.raise_for_status()
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
if not data.get("success"):
|
|
||||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
|
||||||
error_type = data.get("error_type", "unknown")
|
|
||||||
logger.error(
|
|
||||||
f"Agent Generator generation failed: {error_msg} (type: {error_type})"
|
|
||||||
)
|
|
||||||
return _create_error_response(error_msg, error_type)
|
|
||||||
|
|
||||||
return data.get("agent_json")
|
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
error_type, error_msg = _classify_http_error(e)
|
|
||||||
logger.error(error_msg)
|
|
||||||
return _create_error_response(error_msg, error_type)
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
error_type, error_msg = _classify_request_error(e)
|
|
||||||
logger.error(error_msg)
|
|
||||||
return _create_error_response(error_msg, error_type)
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
return _create_error_response(error_msg, "unexpected_error")
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_patch_external(
|
|
||||||
update_request: str,
|
|
||||||
current_agent: dict[str, Any],
|
|
||||||
library_agents: list[dict[str, Any]] | None = None,
|
|
||||||
operation_id: str | None = None,
|
|
||||||
task_id: str | None = None,
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Call the external service to generate a patch for an existing agent.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
update_request: Natural language description of changes
|
|
||||||
current_agent: Current agent JSON
|
|
||||||
library_agents: User's library agents available for sub-agent composition
|
|
||||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
|
||||||
task_id: Task ID for async processing (enables Redis Streams callback)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
|
|
||||||
"""
|
|
||||||
client = _get_client()
|
|
||||||
|
|
||||||
# Build request payload
|
|
||||||
payload: dict[str, Any] = {
|
|
||||||
"update_request": update_request,
|
|
||||||
"current_agent_json": current_agent,
|
|
||||||
}
|
|
||||||
if library_agents:
|
|
||||||
payload["library_agents"] = library_agents
|
|
||||||
if operation_id and task_id:
|
|
||||||
payload["operation_id"] = operation_id
|
|
||||||
payload["task_id"] = task_id
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await client.post("/api/update-agent", json=payload)
|
|
||||||
|
|
||||||
# Handle 202 Accepted for async processing
|
|
||||||
if response.status_code == 202:
|
|
||||||
logger.info(
|
|
||||||
f"Agent Generator accepted async update request "
|
|
||||||
f"(operation_id={operation_id}, task_id={task_id})"
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"status": "accepted",
|
|
||||||
"operation_id": operation_id,
|
|
||||||
"task_id": task_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
response.raise_for_status()
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
if not data.get("success"):
|
|
||||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
|
||||||
error_type = data.get("error_type", "unknown")
|
|
||||||
logger.error(
|
|
||||||
f"Agent Generator patch generation failed: {error_msg} "
|
|
||||||
f"(type: {error_type})"
|
|
||||||
)
|
|
||||||
return _create_error_response(error_msg, error_type)
|
|
||||||
|
|
||||||
# Check if it's clarifying questions
|
|
||||||
if data.get("type") == "clarifying_questions":
|
|
||||||
return {
|
|
||||||
"type": "clarifying_questions",
|
|
||||||
"questions": data.get("questions", []),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Check if it's an error passed through
|
|
||||||
if data.get("type") == "error":
|
|
||||||
return _create_error_response(
|
|
||||||
data.get("error", "Unknown error"),
|
|
||||||
data.get("error_type", "unknown"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Otherwise return the updated agent JSON
|
|
||||||
return data.get("agent_json")
|
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
error_type, error_msg = _classify_http_error(e)
|
|
||||||
logger.error(error_msg)
|
|
||||||
return _create_error_response(error_msg, error_type)
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
error_type, error_msg = _classify_request_error(e)
|
|
||||||
logger.error(error_msg)
|
|
||||||
return _create_error_response(error_msg, error_type)
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
return _create_error_response(error_msg, "unexpected_error")
|
|
||||||
|
|
||||||
|
|
||||||
async def customize_template_external(
|
|
||||||
template_agent: dict[str, Any],
|
|
||||||
modification_request: str,
|
|
||||||
context: str = "",
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Call the external service to customize a template/marketplace agent.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
template_agent: The template agent JSON to customize
|
|
||||||
modification_request: Natural language description of customizations
|
|
||||||
context: Additional context (e.g., answers to previous questions)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Customized agent JSON, clarifying questions dict, or error dict on error
|
|
||||||
"""
|
|
||||||
client = _get_client()
|
|
||||||
|
|
||||||
request = modification_request
|
|
||||||
if context:
|
|
||||||
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
|
|
||||||
|
|
||||||
payload: dict[str, Any] = {
|
|
||||||
"template_agent_json": template_agent,
|
|
||||||
"modification_request": request,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await client.post("/api/template-modification", json=payload)
|
|
||||||
response.raise_for_status()
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
if not data.get("success"):
|
|
||||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
|
||||||
error_type = data.get("error_type", "unknown")
|
|
||||||
logger.error(
|
|
||||||
f"Agent Generator template customization failed: {error_msg} "
|
|
||||||
f"(type: {error_type})"
|
|
||||||
)
|
|
||||||
return _create_error_response(error_msg, error_type)
|
|
||||||
|
|
||||||
# Check if it's clarifying questions
|
|
||||||
if data.get("type") == "clarifying_questions":
|
|
||||||
return {
|
|
||||||
"type": "clarifying_questions",
|
|
||||||
"questions": data.get("questions", []),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Check if it's an error passed through
|
|
||||||
if data.get("type") == "error":
|
|
||||||
return _create_error_response(
|
|
||||||
data.get("error", "Unknown error"),
|
|
||||||
data.get("error_type", "unknown"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Otherwise return the customized agent JSON
|
|
||||||
return data.get("agent_json")
|
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
error_type, error_msg = _classify_http_error(e)
|
|
||||||
logger.error(error_msg)
|
|
||||||
return _create_error_response(error_msg, error_type)
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
error_type, error_msg = _classify_request_error(e)
|
|
||||||
logger.error(error_msg)
|
|
||||||
return _create_error_response(error_msg, error_type)
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
return _create_error_response(error_msg, "unexpected_error")
|
|
||||||
|
|
||||||
|
|
||||||
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
|
||||||
"""Get available blocks from the external service.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of block info dicts or None on error
|
|
||||||
"""
|
|
||||||
client = _get_client()
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await client.get("/api/blocks")
|
|
||||||
response.raise_for_status()
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
if not data.get("success"):
|
|
||||||
logger.error("External service returned error getting blocks")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return data.get("blocks", [])
|
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
logger.error(f"HTTP error getting blocks from external service: {e}")
|
|
||||||
return None
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
logger.error(f"Request error getting blocks from external service: {e}")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Unexpected error getting blocks from external service: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def health_check() -> bool:
|
|
||||||
"""Check if the external service is healthy.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if healthy, False otherwise
|
|
||||||
"""
|
|
||||||
if not is_external_service_configured():
|
|
||||||
return False
|
|
||||||
|
|
||||||
client = _get_client()
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await client.get("/health")
|
|
||||||
response.raise_for_status()
|
|
||||||
data = response.json()
|
|
||||||
return data.get("status") == "healthy" and data.get("blocks_loaded", False)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"External agent generator health check failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def close_client() -> None:
|
|
||||||
"""Close the HTTP client."""
|
|
||||||
global _client
|
|
||||||
if _client is not None:
|
|
||||||
await _client.aclose()
|
|
||||||
_client = None
|
|
||||||
@@ -1,129 +0,0 @@
|
|||||||
"""Base classes and shared utilities for chat tools."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from openai.types.chat import ChatCompletionToolParam
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
|
||||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
|
||||||
|
|
||||||
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseTool:
|
|
||||||
"""Base class for all chat tools."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
"""Tool name for OpenAI function calling."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
"""Tool description for OpenAI."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> dict[str, Any]:
|
|
||||||
"""Tool parameters schema for OpenAI."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_auth(self) -> bool:
|
|
||||||
"""Whether this tool requires authentication."""
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_long_running(self) -> bool:
|
|
||||||
"""Whether this tool is long-running and should execute in background.
|
|
||||||
|
|
||||||
Long-running tools (like agent generation) are executed via background
|
|
||||||
tasks to survive SSE disconnections. The result is persisted to chat
|
|
||||||
history and visible when the user refreshes.
|
|
||||||
"""
|
|
||||||
return False
|
|
||||||
|
|
||||||
def as_openai_tool(self) -> ChatCompletionToolParam:
|
|
||||||
"""Convert to OpenAI tool format."""
|
|
||||||
return ChatCompletionToolParam(
|
|
||||||
type="function",
|
|
||||||
function={
|
|
||||||
"name": self.name,
|
|
||||||
"description": self.description,
|
|
||||||
"parameters": self.parameters,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def execute(
|
|
||||||
self,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
tool_call_id: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> StreamToolOutputAvailable:
|
|
||||||
"""Execute the tool with authentication check.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User ID (may be anonymous like "anon_123")
|
|
||||||
session_id: Chat session ID
|
|
||||||
**kwargs: Tool-specific parameters
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Pydantic response object
|
|
||||||
|
|
||||||
"""
|
|
||||||
if self.requires_auth and not user_id:
|
|
||||||
logger.error(
|
|
||||||
f"Attempted tool call for {self.name} but user not authenticated"
|
|
||||||
)
|
|
||||||
return StreamToolOutputAvailable(
|
|
||||||
toolCallId=tool_call_id,
|
|
||||||
toolName=self.name,
|
|
||||||
output=NeedLoginResponse(
|
|
||||||
message=f"Please sign in to use {self.name}",
|
|
||||||
session_id=session.session_id,
|
|
||||||
).model_dump_json(),
|
|
||||||
success=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await self._execute(user_id, session, **kwargs)
|
|
||||||
return StreamToolOutputAvailable(
|
|
||||||
toolCallId=tool_call_id,
|
|
||||||
toolName=self.name,
|
|
||||||
output=result.model_dump_json(),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in {self.name}: {e}", exc_info=True)
|
|
||||||
return StreamToolOutputAvailable(
|
|
||||||
toolCallId=tool_call_id,
|
|
||||||
toolName=self.name,
|
|
||||||
output=ErrorResponse(
|
|
||||||
message=f"An error occurred while executing {self.name}",
|
|
||||||
error=str(e),
|
|
||||||
session_id=session.session_id,
|
|
||||||
).model_dump_json(),
|
|
||||||
success=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _execute(
|
|
||||||
self,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
**kwargs,
|
|
||||||
) -> ToolResponseBase:
|
|
||||||
"""Internal execution logic to be implemented by subclasses.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User ID (authenticated or anonymous)
|
|
||||||
session_id: Chat session ID
|
|
||||||
**kwargs: Tool-specific parameters
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Pydantic response object
|
|
||||||
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
@@ -1,335 +0,0 @@
|
|||||||
"""CreateAgentTool - Creates agents from natural language descriptions."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
|
||||||
|
|
||||||
from .agent_generator import (
|
|
||||||
AgentGeneratorNotConfiguredError,
|
|
||||||
decompose_goal,
|
|
||||||
enrich_library_agents_from_steps,
|
|
||||||
generate_agent,
|
|
||||||
get_all_relevant_agents_for_generation,
|
|
||||||
get_user_message_for_error,
|
|
||||||
save_agent_to_library,
|
|
||||||
)
|
|
||||||
from .base import BaseTool
|
|
||||||
from .models import (
|
|
||||||
AgentPreviewResponse,
|
|
||||||
AgentSavedResponse,
|
|
||||||
AsyncProcessingResponse,
|
|
||||||
ClarificationNeededResponse,
|
|
||||||
ClarifyingQuestion,
|
|
||||||
ErrorResponse,
|
|
||||||
ToolResponseBase,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class CreateAgentTool(BaseTool):
|
|
||||||
"""Tool for creating agents from natural language descriptions."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "create_agent"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return (
|
|
||||||
"Create a new agent workflow from a natural language description. "
|
|
||||||
"First generates a preview, then saves to library if save=true."
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_auth(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_long_running(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"description": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"Natural language description of what the agent should do. "
|
|
||||||
"Be specific about inputs, outputs, and the workflow steps."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"context": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"Additional context or answers to previous clarifying questions. "
|
|
||||||
"Include any preferences or constraints mentioned by the user."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"save": {
|
|
||||||
"type": "boolean",
|
|
||||||
"description": (
|
|
||||||
"Whether to save the agent to the user's library. "
|
|
||||||
"Default is true. Set to false for preview only."
|
|
||||||
),
|
|
||||||
"default": True,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["description"],
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _execute(
|
|
||||||
self,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
**kwargs,
|
|
||||||
) -> ToolResponseBase:
|
|
||||||
"""Execute the create_agent tool.
|
|
||||||
|
|
||||||
Flow:
|
|
||||||
1. Decompose the description into steps (may return clarifying questions)
|
|
||||||
2. Generate agent JSON (external service handles fixing and validation)
|
|
||||||
3. Preview or save based on the save parameter
|
|
||||||
"""
|
|
||||||
description = kwargs.get("description", "").strip()
|
|
||||||
context = kwargs.get("context", "")
|
|
||||||
save = kwargs.get("save", True)
|
|
||||||
session_id = session.session_id if session else None
|
|
||||||
|
|
||||||
# Extract async processing params (passed by long-running tool handler)
|
|
||||||
operation_id = kwargs.get("_operation_id")
|
|
||||||
task_id = kwargs.get("_task_id")
|
|
||||||
|
|
||||||
if not description:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Please provide a description of what the agent should do.",
|
|
||||||
error="Missing description parameter",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
library_agents = None
|
|
||||||
if user_id:
|
|
||||||
try:
|
|
||||||
library_agents = await get_all_relevant_agents_for_generation(
|
|
||||||
user_id=user_id,
|
|
||||||
search_query=description,
|
|
||||||
include_marketplace=True,
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"Found {len(library_agents)} relevant agents for sub-agent composition"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to fetch library agents: {e}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
decomposition_result = await decompose_goal(
|
|
||||||
description, context, library_agents
|
|
||||||
)
|
|
||||||
except AgentGeneratorNotConfiguredError:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
"Agent generation is not available. "
|
|
||||||
"The Agent Generator service is not configured."
|
|
||||||
),
|
|
||||||
error="service_not_configured",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if decomposition_result is None:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
|
|
||||||
error="decomposition_failed",
|
|
||||||
details={"description": description[:100]},
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if decomposition_result.get("type") == "error":
|
|
||||||
error_msg = decomposition_result.get("error", "Unknown error")
|
|
||||||
error_type = decomposition_result.get("error_type", "unknown")
|
|
||||||
user_message = get_user_message_for_error(
|
|
||||||
error_type,
|
|
||||||
operation="analyze the goal",
|
|
||||||
llm_parse_message="The AI had trouble understanding this request. Please try rephrasing your goal.",
|
|
||||||
)
|
|
||||||
return ErrorResponse(
|
|
||||||
message=user_message,
|
|
||||||
error=f"decomposition_failed:{error_type}",
|
|
||||||
details={
|
|
||||||
"description": description[:100],
|
|
||||||
"service_error": error_msg,
|
|
||||||
"error_type": error_type,
|
|
||||||
},
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if decomposition_result.get("type") == "clarifying_questions":
|
|
||||||
questions = decomposition_result.get("questions", [])
|
|
||||||
return ClarificationNeededResponse(
|
|
||||||
message=(
|
|
||||||
"I need some more information to create this agent. "
|
|
||||||
"Please answer the following questions:"
|
|
||||||
),
|
|
||||||
questions=[
|
|
||||||
ClarifyingQuestion(
|
|
||||||
question=q.get("question", ""),
|
|
||||||
keyword=q.get("keyword", ""),
|
|
||||||
example=q.get("example"),
|
|
||||||
)
|
|
||||||
for q in questions
|
|
||||||
],
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if decomposition_result.get("type") == "unachievable_goal":
|
|
||||||
suggested = decomposition_result.get("suggested_goal", "")
|
|
||||||
reason = decomposition_result.get("reason", "")
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
f"This goal cannot be accomplished with the available blocks. "
|
|
||||||
f"{reason} "
|
|
||||||
f"Suggestion: {suggested}"
|
|
||||||
),
|
|
||||||
error="unachievable_goal",
|
|
||||||
details={"suggested_goal": suggested, "reason": reason},
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if decomposition_result.get("type") == "vague_goal":
|
|
||||||
suggested = decomposition_result.get("suggested_goal", "")
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
f"The goal is too vague to create a specific workflow. "
|
|
||||||
f"Suggestion: {suggested}"
|
|
||||||
),
|
|
||||||
error="vague_goal",
|
|
||||||
details={"suggested_goal": suggested},
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if user_id and library_agents is not None:
|
|
||||||
try:
|
|
||||||
library_agents = await enrich_library_agents_from_steps(
|
|
||||||
user_id=user_id,
|
|
||||||
decomposition_result=decomposition_result,
|
|
||||||
existing_agents=library_agents,
|
|
||||||
include_marketplace=True,
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"After enrichment: {len(library_agents)} total agents for sub-agent composition"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to enrich library agents from steps: {e}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
agent_json = await generate_agent(
|
|
||||||
decomposition_result,
|
|
||||||
library_agents,
|
|
||||||
operation_id=operation_id,
|
|
||||||
task_id=task_id,
|
|
||||||
)
|
|
||||||
except AgentGeneratorNotConfiguredError:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
"Agent generation is not available. "
|
|
||||||
"The Agent Generator service is not configured."
|
|
||||||
),
|
|
||||||
error="service_not_configured",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if agent_json is None:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
|
|
||||||
error="generation_failed",
|
|
||||||
details={"description": description[:100]},
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(agent_json, dict) and agent_json.get("type") == "error":
|
|
||||||
error_msg = agent_json.get("error", "Unknown error")
|
|
||||||
error_type = agent_json.get("error_type", "unknown")
|
|
||||||
user_message = get_user_message_for_error(
|
|
||||||
error_type,
|
|
||||||
operation="generate the agent",
|
|
||||||
llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
|
|
||||||
validation_message=(
|
|
||||||
"I wasn't able to create a valid agent for this request. "
|
|
||||||
"The generated workflow had some structural issues. "
|
|
||||||
"Please try simplifying your goal or breaking it into smaller steps."
|
|
||||||
),
|
|
||||||
error_details=error_msg,
|
|
||||||
)
|
|
||||||
return ErrorResponse(
|
|
||||||
message=user_message,
|
|
||||||
error=f"generation_failed:{error_type}",
|
|
||||||
details={
|
|
||||||
"description": description[:100],
|
|
||||||
"service_error": error_msg,
|
|
||||||
"error_type": error_type,
|
|
||||||
},
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if Agent Generator accepted for async processing
|
|
||||||
if agent_json.get("status") == "accepted":
|
|
||||||
logger.info(
|
|
||||||
f"Agent generation delegated to async processing "
|
|
||||||
f"(operation_id={operation_id}, task_id={task_id})"
|
|
||||||
)
|
|
||||||
return AsyncProcessingResponse(
|
|
||||||
message="Agent generation started. You'll be notified when it's complete.",
|
|
||||||
operation_id=operation_id,
|
|
||||||
task_id=task_id,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
agent_name = agent_json.get("name", "Generated Agent")
|
|
||||||
agent_description = agent_json.get("description", "")
|
|
||||||
node_count = len(agent_json.get("nodes", []))
|
|
||||||
link_count = len(agent_json.get("links", []))
|
|
||||||
|
|
||||||
if not save:
|
|
||||||
return AgentPreviewResponse(
|
|
||||||
message=(
|
|
||||||
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
|
|
||||||
f"Review it and call create_agent with save=true to save it to your library."
|
|
||||||
),
|
|
||||||
agent_json=agent_json,
|
|
||||||
agent_name=agent_name,
|
|
||||||
description=agent_description,
|
|
||||||
node_count=node_count,
|
|
||||||
link_count=link_count,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not user_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="You must be logged in to save agents.",
|
|
||||||
error="auth_required",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
created_graph, library_agent = await save_agent_to_library(
|
|
||||||
agent_json, user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
return AgentSavedResponse(
|
|
||||||
message=f"Agent '{created_graph.name}' has been saved to your library!",
|
|
||||||
agent_id=created_graph.id,
|
|
||||||
agent_name=created_graph.name,
|
|
||||||
library_agent_id=library_agent.id,
|
|
||||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
|
||||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Failed to save the agent: {str(e)}",
|
|
||||||
error="save_failed",
|
|
||||||
details={"exception": str(e)},
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
@@ -1,337 +0,0 @@
|
|||||||
"""CustomizeAgentTool - Customizes marketplace/template agents using natural language."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
|
||||||
from backend.api.features.store import db as store_db
|
|
||||||
from backend.api.features.store.exceptions import AgentNotFoundError
|
|
||||||
|
|
||||||
from .agent_generator import (
|
|
||||||
AgentGeneratorNotConfiguredError,
|
|
||||||
customize_template,
|
|
||||||
get_user_message_for_error,
|
|
||||||
graph_to_json,
|
|
||||||
save_agent_to_library,
|
|
||||||
)
|
|
||||||
from .base import BaseTool
|
|
||||||
from .models import (
|
|
||||||
AgentPreviewResponse,
|
|
||||||
AgentSavedResponse,
|
|
||||||
ClarificationNeededResponse,
|
|
||||||
ClarifyingQuestion,
|
|
||||||
ErrorResponse,
|
|
||||||
ToolResponseBase,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class CustomizeAgentTool(BaseTool):
|
|
||||||
"""Tool for customizing marketplace/template agents using natural language."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "customize_agent"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return (
|
|
||||||
"Customize a marketplace or template agent using natural language. "
|
|
||||||
"Takes an existing agent from the marketplace and modifies it based on "
|
|
||||||
"the user's requirements before adding to their library."
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_auth(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_long_running(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"agent_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"The marketplace agent ID in format 'creator/slug' "
|
|
||||||
"(e.g., 'autogpt/newsletter-writer'). "
|
|
||||||
"Get this from find_agent results."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"modifications": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"Natural language description of how to customize the agent. "
|
|
||||||
"Be specific about what changes you want to make."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"context": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"Additional context or answers to previous clarifying questions."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"save": {
|
|
||||||
"type": "boolean",
|
|
||||||
"description": (
|
|
||||||
"Whether to save the customized agent to the user's library. "
|
|
||||||
"Default is true. Set to false for preview only."
|
|
||||||
),
|
|
||||||
"default": True,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["agent_id", "modifications"],
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _execute(
|
|
||||||
self,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
**kwargs,
|
|
||||||
) -> ToolResponseBase:
|
|
||||||
"""Execute the customize_agent tool.
|
|
||||||
|
|
||||||
Flow:
|
|
||||||
1. Parse the agent ID to get creator/slug
|
|
||||||
2. Fetch the template agent from the marketplace
|
|
||||||
3. Call customize_template with the modification request
|
|
||||||
4. Preview or save based on the save parameter
|
|
||||||
"""
|
|
||||||
agent_id = kwargs.get("agent_id", "").strip()
|
|
||||||
modifications = kwargs.get("modifications", "").strip()
|
|
||||||
context = kwargs.get("context", "")
|
|
||||||
save = kwargs.get("save", True)
|
|
||||||
session_id = session.session_id if session else None
|
|
||||||
|
|
||||||
if not agent_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
|
|
||||||
error="missing_agent_id",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not modifications:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Please describe how you want to customize this agent.",
|
|
||||||
error="missing_modifications",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Parse agent_id in format "creator/slug"
|
|
||||||
parts = [p.strip() for p in agent_id.split("/")]
|
|
||||||
if len(parts) != 2 or not parts[0] or not parts[1]:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
f"Invalid agent ID format: '{agent_id}'. "
|
|
||||||
"Expected format is 'creator/agent-name' "
|
|
||||||
"(e.g., 'autogpt/newsletter-writer')."
|
|
||||||
),
|
|
||||||
error="invalid_agent_id_format",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
creator_username, agent_slug = parts
|
|
||||||
|
|
||||||
# Fetch the marketplace agent details
|
|
||||||
try:
|
|
||||||
agent_details = await store_db.get_store_agent_details(
|
|
||||||
username=creator_username, agent_name=agent_slug
|
|
||||||
)
|
|
||||||
except AgentNotFoundError:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
f"Could not find marketplace agent '{agent_id}'. "
|
|
||||||
"Please check the agent ID and try again."
|
|
||||||
),
|
|
||||||
error="agent_not_found",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Failed to fetch the marketplace agent. Please try again.",
|
|
||||||
error="fetch_error",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not agent_details.store_listing_version_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
f"The agent '{agent_id}' does not have an available version. "
|
|
||||||
"Please try a different agent."
|
|
||||||
),
|
|
||||||
error="no_version_available",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the full agent graph
|
|
||||||
try:
|
|
||||||
graph = await store_db.get_agent(agent_details.store_listing_version_id)
|
|
||||||
template_agent = graph_to_json(graph)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Failed to fetch the agent configuration. Please try again.",
|
|
||||||
error="graph_fetch_error",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Call customize_template
|
|
||||||
try:
|
|
||||||
result = await customize_template(
|
|
||||||
template_agent=template_agent,
|
|
||||||
modification_request=modifications,
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
except AgentGeneratorNotConfiguredError:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
"Agent customization is not available. "
|
|
||||||
"The Agent Generator service is not configured."
|
|
||||||
),
|
|
||||||
error="service_not_configured",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error calling customize_template for {agent_id}: {e}")
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
"Failed to customize the agent due to a service error. "
|
|
||||||
"Please try again."
|
|
||||||
),
|
|
||||||
error="customization_service_error",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if result is None:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
"Failed to customize the agent. "
|
|
||||||
"The agent generation service may be unavailable or timed out. "
|
|
||||||
"Please try again."
|
|
||||||
),
|
|
||||||
error="customization_failed",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle error response
|
|
||||||
if isinstance(result, dict) and result.get("type") == "error":
|
|
||||||
error_msg = result.get("error", "Unknown error")
|
|
||||||
error_type = result.get("error_type", "unknown")
|
|
||||||
user_message = get_user_message_for_error(
|
|
||||||
error_type,
|
|
||||||
operation="customize the agent",
|
|
||||||
llm_parse_message=(
|
|
||||||
"The AI had trouble customizing the agent. "
|
|
||||||
"Please try again or simplify your request."
|
|
||||||
),
|
|
||||||
validation_message=(
|
|
||||||
"The customized agent failed validation. "
|
|
||||||
"Please try rephrasing your request."
|
|
||||||
),
|
|
||||||
error_details=error_msg,
|
|
||||||
)
|
|
||||||
return ErrorResponse(
|
|
||||||
message=user_message,
|
|
||||||
error=f"customization_failed:{error_type}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle clarifying questions
|
|
||||||
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
|
|
||||||
questions = result.get("questions") or []
|
|
||||||
if not isinstance(questions, list):
|
|
||||||
logger.error(
|
|
||||||
f"Unexpected clarifying questions format: {type(questions)}"
|
|
||||||
)
|
|
||||||
questions = []
|
|
||||||
return ClarificationNeededResponse(
|
|
||||||
message=(
|
|
||||||
"I need some more information to customize this agent. "
|
|
||||||
"Please answer the following questions:"
|
|
||||||
),
|
|
||||||
questions=[
|
|
||||||
ClarifyingQuestion(
|
|
||||||
question=q.get("question", ""),
|
|
||||||
keyword=q.get("keyword", ""),
|
|
||||||
example=q.get("example"),
|
|
||||||
)
|
|
||||||
for q in questions
|
|
||||||
if isinstance(q, dict)
|
|
||||||
],
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Result should be the customized agent JSON
|
|
||||||
if not isinstance(result, dict):
|
|
||||||
logger.error(f"Unexpected customize_template response type: {type(result)}")
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Failed to customize the agent due to an unexpected response.",
|
|
||||||
error="unexpected_response_type",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
customized_agent = result
|
|
||||||
|
|
||||||
agent_name = customized_agent.get(
|
|
||||||
"name", f"Customized {agent_details.agent_name}"
|
|
||||||
)
|
|
||||||
agent_description = customized_agent.get("description", "")
|
|
||||||
nodes = customized_agent.get("nodes")
|
|
||||||
links = customized_agent.get("links")
|
|
||||||
node_count = len(nodes) if isinstance(nodes, list) else 0
|
|
||||||
link_count = len(links) if isinstance(links, list) else 0
|
|
||||||
|
|
||||||
if not save:
|
|
||||||
return AgentPreviewResponse(
|
|
||||||
message=(
|
|
||||||
f"I've customized the agent '{agent_details.agent_name}'. "
|
|
||||||
f"The customized agent has {node_count} blocks. "
|
|
||||||
f"Review it and call customize_agent with save=true to save it."
|
|
||||||
),
|
|
||||||
agent_json=customized_agent,
|
|
||||||
agent_name=agent_name,
|
|
||||||
description=agent_description,
|
|
||||||
node_count=node_count,
|
|
||||||
link_count=link_count,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not user_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="You must be logged in to save agents.",
|
|
||||||
error="auth_required",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save to user's library
|
|
||||||
try:
|
|
||||||
created_graph, library_agent = await save_agent_to_library(
|
|
||||||
customized_agent, user_id, is_update=False
|
|
||||||
)
|
|
||||||
|
|
||||||
return AgentSavedResponse(
|
|
||||||
message=(
|
|
||||||
f"Customized agent '{created_graph.name}' "
|
|
||||||
f"(based on '{agent_details.agent_name}') "
|
|
||||||
f"has been saved to your library!"
|
|
||||||
),
|
|
||||||
agent_id=created_graph.id,
|
|
||||||
agent_name=created_graph.name,
|
|
||||||
library_agent_id=library_agent.id,
|
|
||||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
|
||||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error saving customized agent: {e}")
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Failed to save the customized agent. Please try again.",
|
|
||||||
error="save_failed",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
@@ -1,284 +0,0 @@
|
|||||||
"""EditAgentTool - Edits existing agents using natural language."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
|
||||||
|
|
||||||
from .agent_generator import (
|
|
||||||
AgentGeneratorNotConfiguredError,
|
|
||||||
generate_agent_patch,
|
|
||||||
get_agent_as_json,
|
|
||||||
get_all_relevant_agents_for_generation,
|
|
||||||
get_user_message_for_error,
|
|
||||||
save_agent_to_library,
|
|
||||||
)
|
|
||||||
from .base import BaseTool
|
|
||||||
from .models import (
|
|
||||||
AgentPreviewResponse,
|
|
||||||
AgentSavedResponse,
|
|
||||||
AsyncProcessingResponse,
|
|
||||||
ClarificationNeededResponse,
|
|
||||||
ClarifyingQuestion,
|
|
||||||
ErrorResponse,
|
|
||||||
ToolResponseBase,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class EditAgentTool(BaseTool):
|
|
||||||
"""Tool for editing existing agents using natural language."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "edit_agent"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return (
|
|
||||||
"Edit an existing agent from the user's library using natural language. "
|
|
||||||
"Generates updates to the agent while preserving unchanged parts."
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_auth(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_long_running(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"agent_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"The ID of the agent to edit. "
|
|
||||||
"Can be a graph ID or library agent ID."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"changes": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"Natural language description of what changes to make. "
|
|
||||||
"Be specific about what to add, remove, or modify."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"context": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"Additional context or answers to previous clarifying questions."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"save": {
|
|
||||||
"type": "boolean",
|
|
||||||
"description": (
|
|
||||||
"Whether to save the changes. "
|
|
||||||
"Default is true. Set to false for preview only."
|
|
||||||
),
|
|
||||||
"default": True,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["agent_id", "changes"],
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _execute(
|
|
||||||
self,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
**kwargs,
|
|
||||||
) -> ToolResponseBase:
|
|
||||||
"""Execute the edit_agent tool.
|
|
||||||
|
|
||||||
Flow:
|
|
||||||
1. Fetch the current agent
|
|
||||||
2. Generate updated agent (external service handles fixing and validation)
|
|
||||||
3. Preview or save based on the save parameter
|
|
||||||
"""
|
|
||||||
agent_id = kwargs.get("agent_id", "").strip()
|
|
||||||
changes = kwargs.get("changes", "").strip()
|
|
||||||
context = kwargs.get("context", "")
|
|
||||||
save = kwargs.get("save", True)
|
|
||||||
session_id = session.session_id if session else None
|
|
||||||
|
|
||||||
# Extract async processing params (passed by long-running tool handler)
|
|
||||||
operation_id = kwargs.get("_operation_id")
|
|
||||||
task_id = kwargs.get("_task_id")
|
|
||||||
|
|
||||||
if not agent_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Please provide the agent ID to edit.",
|
|
||||||
error="Missing agent_id parameter",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not changes:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Please describe what changes you want to make.",
|
|
||||||
error="Missing changes parameter",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
current_agent = await get_agent_as_json(agent_id, user_id)
|
|
||||||
|
|
||||||
if current_agent is None:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Could not find agent with ID '{agent_id}' in your library.",
|
|
||||||
error="agent_not_found",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
library_agents = None
|
|
||||||
if user_id:
|
|
||||||
try:
|
|
||||||
graph_id = current_agent.get("id")
|
|
||||||
library_agents = await get_all_relevant_agents_for_generation(
|
|
||||||
user_id=user_id,
|
|
||||||
search_query=changes,
|
|
||||||
exclude_graph_id=graph_id,
|
|
||||||
include_marketplace=True,
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"Found {len(library_agents)} relevant agents for sub-agent composition"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to fetch library agents: {e}")
|
|
||||||
|
|
||||||
update_request = changes
|
|
||||||
if context:
|
|
||||||
update_request = f"{changes}\n\nAdditional context:\n{context}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await generate_agent_patch(
|
|
||||||
update_request,
|
|
||||||
current_agent,
|
|
||||||
library_agents,
|
|
||||||
operation_id=operation_id,
|
|
||||||
task_id=task_id,
|
|
||||||
)
|
|
||||||
except AgentGeneratorNotConfiguredError:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
"Agent editing is not available. "
|
|
||||||
"The Agent Generator service is not configured."
|
|
||||||
),
|
|
||||||
error="service_not_configured",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if result is None:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.",
|
|
||||||
error="update_generation_failed",
|
|
||||||
details={"agent_id": agent_id, "changes": changes[:100]},
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if Agent Generator accepted for async processing
|
|
||||||
if result.get("status") == "accepted":
|
|
||||||
logger.info(
|
|
||||||
f"Agent edit delegated to async processing "
|
|
||||||
f"(operation_id={operation_id}, task_id={task_id})"
|
|
||||||
)
|
|
||||||
return AsyncProcessingResponse(
|
|
||||||
message="Agent edit started. You'll be notified when it's complete.",
|
|
||||||
operation_id=operation_id,
|
|
||||||
task_id=task_id,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if the result is an error from the external service
|
|
||||||
if isinstance(result, dict) and result.get("type") == "error":
|
|
||||||
error_msg = result.get("error", "Unknown error")
|
|
||||||
error_type = result.get("error_type", "unknown")
|
|
||||||
user_message = get_user_message_for_error(
|
|
||||||
error_type,
|
|
||||||
operation="generate the changes",
|
|
||||||
llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
|
|
||||||
validation_message="The generated changes failed validation. Please try rephrasing your request.",
|
|
||||||
error_details=error_msg,
|
|
||||||
)
|
|
||||||
return ErrorResponse(
|
|
||||||
message=user_message,
|
|
||||||
error=f"update_generation_failed:{error_type}",
|
|
||||||
details={
|
|
||||||
"agent_id": agent_id,
|
|
||||||
"changes": changes[:100],
|
|
||||||
"service_error": error_msg,
|
|
||||||
"error_type": error_type,
|
|
||||||
},
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.get("type") == "clarifying_questions":
|
|
||||||
questions = result.get("questions", [])
|
|
||||||
return ClarificationNeededResponse(
|
|
||||||
message=(
|
|
||||||
"I need some more information about the changes. "
|
|
||||||
"Please answer the following questions:"
|
|
||||||
),
|
|
||||||
questions=[
|
|
||||||
ClarifyingQuestion(
|
|
||||||
question=q.get("question", ""),
|
|
||||||
keyword=q.get("keyword", ""),
|
|
||||||
example=q.get("example"),
|
|
||||||
)
|
|
||||||
for q in questions
|
|
||||||
],
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
updated_agent = result
|
|
||||||
|
|
||||||
agent_name = updated_agent.get("name", "Updated Agent")
|
|
||||||
agent_description = updated_agent.get("description", "")
|
|
||||||
node_count = len(updated_agent.get("nodes", []))
|
|
||||||
link_count = len(updated_agent.get("links", []))
|
|
||||||
|
|
||||||
if not save:
|
|
||||||
return AgentPreviewResponse(
|
|
||||||
message=(
|
|
||||||
f"I've updated the agent. "
|
|
||||||
f"The agent now has {node_count} blocks. "
|
|
||||||
f"Review it and call edit_agent with save=true to save the changes."
|
|
||||||
),
|
|
||||||
agent_json=updated_agent,
|
|
||||||
agent_name=agent_name,
|
|
||||||
description=agent_description,
|
|
||||||
node_count=node_count,
|
|
||||||
link_count=link_count,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not user_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="You must be logged in to save agents.",
|
|
||||||
error="auth_required",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
created_graph, library_agent = await save_agent_to_library(
|
|
||||||
updated_agent, user_id, is_update=True
|
|
||||||
)
|
|
||||||
|
|
||||||
return AgentSavedResponse(
|
|
||||||
message=f"Updated agent '{created_graph.name}' has been saved to your library!",
|
|
||||||
agent_id=created_graph.id,
|
|
||||||
agent_name=created_graph.name,
|
|
||||||
library_agent_id=library_agent.id,
|
|
||||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
|
||||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Failed to save the updated agent: {str(e)}",
|
|
||||||
error="save_failed",
|
|
||||||
details={"exception": str(e)},
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
@@ -1,139 +0,0 @@
|
|||||||
"""Tests for block filtering in FindBlockTool."""
|
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from backend.api.features.chat.tools.find_block import (
|
|
||||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
|
||||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
|
||||||
FindBlockTool,
|
|
||||||
)
|
|
||||||
from backend.api.features.chat.tools.models import BlockListResponse
|
|
||||||
from backend.blocks._base import BlockType
|
|
||||||
|
|
||||||
from ._test_data import make_session
|
|
||||||
|
|
||||||
_TEST_USER_ID = "test-user-find-block"
|
|
||||||
|
|
||||||
|
|
||||||
def make_mock_block(
|
|
||||||
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
|
||||||
):
|
|
||||||
"""Create a mock block for testing."""
|
|
||||||
mock = MagicMock()
|
|
||||||
mock.id = block_id
|
|
||||||
mock.name = name
|
|
||||||
mock.description = f"{name} description"
|
|
||||||
mock.block_type = block_type
|
|
||||||
mock.disabled = disabled
|
|
||||||
mock.input_schema = MagicMock()
|
|
||||||
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
|
||||||
mock.input_schema.get_credentials_fields.return_value = {}
|
|
||||||
mock.output_schema = MagicMock()
|
|
||||||
mock.output_schema.jsonschema.return_value = {}
|
|
||||||
mock.categories = []
|
|
||||||
return mock
|
|
||||||
|
|
||||||
|
|
||||||
class TestFindBlockFiltering:
|
|
||||||
"""Tests for block filtering in FindBlockTool."""
|
|
||||||
|
|
||||||
def test_excluded_block_types_contains_expected_types(self):
|
|
||||||
"""Verify COPILOT_EXCLUDED_BLOCK_TYPES contains all graph-only types."""
|
|
||||||
assert BlockType.INPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.OUTPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.WEBHOOK in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.WEBHOOK_MANUAL in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.NOTE in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
|
|
||||||
def test_excluded_block_ids_contains_smart_decision_maker(self):
|
|
||||||
"""Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
|
|
||||||
assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_excluded_block_type_filtered_from_results(self):
|
|
||||||
"""Verify blocks with excluded BlockTypes are filtered from search results."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
# Mock search returns an INPUT block (excluded) and a STANDARD block (included)
|
|
||||||
search_results = [
|
|
||||||
{"content_id": "input-block-id", "score": 0.9},
|
|
||||||
{"content_id": "standard-block-id", "score": 0.8},
|
|
||||||
]
|
|
||||||
|
|
||||||
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
|
||||||
standard_block = make_mock_block(
|
|
||||||
"standard-block-id", "HTTP Request", BlockType.STANDARD
|
|
||||||
)
|
|
||||||
|
|
||||||
def mock_get_block(block_id):
|
|
||||||
return {
|
|
||||||
"input-block-id": input_block,
|
|
||||||
"standard-block-id": standard_block,
|
|
||||||
}.get(block_id)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
return_value=(search_results, 2),
|
|
||||||
):
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.find_block.get_block",
|
|
||||||
side_effect=mock_get_block,
|
|
||||||
):
|
|
||||||
tool = FindBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID, session=session, query="test"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should only return the standard block, not the INPUT block
|
|
||||||
assert isinstance(response, BlockListResponse)
|
|
||||||
assert len(response.blocks) == 1
|
|
||||||
assert response.blocks[0].id == "standard-block-id"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_excluded_block_id_filtered_from_results(self):
|
|
||||||
"""Verify SmartDecisionMakerBlock is filtered from search results."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
|
||||||
search_results = [
|
|
||||||
{"content_id": smart_decision_id, "score": 0.9},
|
|
||||||
{"content_id": "normal-block-id", "score": 0.8},
|
|
||||||
]
|
|
||||||
|
|
||||||
# SmartDecisionMakerBlock has STANDARD type but is excluded by ID
|
|
||||||
smart_block = make_mock_block(
|
|
||||||
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
|
||||||
)
|
|
||||||
normal_block = make_mock_block(
|
|
||||||
"normal-block-id", "Normal Block", BlockType.STANDARD
|
|
||||||
)
|
|
||||||
|
|
||||||
def mock_get_block(block_id):
|
|
||||||
return {
|
|
||||||
smart_decision_id: smart_block,
|
|
||||||
"normal-block-id": normal_block,
|
|
||||||
}.get(block_id)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
return_value=(search_results, 2),
|
|
||||||
):
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.find_block.get_block",
|
|
||||||
side_effect=mock_get_block,
|
|
||||||
):
|
|
||||||
tool = FindBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID, session=session, query="decision"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should only return normal block, not SmartDecisionMakerBlock
|
|
||||||
assert isinstance(response, BlockListResponse)
|
|
||||||
assert len(response.blocks) == 1
|
|
||||||
assert response.blocks[0].id == "normal-block-id"
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
"""Shared helpers for chat tools."""
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
def get_inputs_from_schema(
|
|
||||||
input_schema: dict[str, Any],
|
|
||||||
exclude_fields: set[str] | None = None,
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""Extract input field info from JSON schema."""
|
|
||||||
if not isinstance(input_schema, dict):
|
|
||||||
return []
|
|
||||||
|
|
||||||
exclude = exclude_fields or set()
|
|
||||||
properties = input_schema.get("properties", {})
|
|
||||||
required = set(input_schema.get("required", []))
|
|
||||||
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"name": name,
|
|
||||||
"title": schema.get("title", name),
|
|
||||||
"type": schema.get("type", "string"),
|
|
||||||
"description": schema.get("description", ""),
|
|
||||||
"required": name in required,
|
|
||||||
"default": schema.get("default"),
|
|
||||||
}
|
|
||||||
for name, schema in properties.items()
|
|
||||||
if name not in exclude
|
|
||||||
]
|
|
||||||
@@ -1,423 +0,0 @@
|
|||||||
"""Pydantic models for tool responses."""
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from backend.data.model import CredentialsMetaInput
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseType(str, Enum):
|
|
||||||
"""Types of tool responses."""
|
|
||||||
|
|
||||||
AGENTS_FOUND = "agents_found"
|
|
||||||
AGENT_DETAILS = "agent_details"
|
|
||||||
SETUP_REQUIREMENTS = "setup_requirements"
|
|
||||||
EXECUTION_STARTED = "execution_started"
|
|
||||||
NEED_LOGIN = "need_login"
|
|
||||||
ERROR = "error"
|
|
||||||
NO_RESULTS = "no_results"
|
|
||||||
AGENT_OUTPUT = "agent_output"
|
|
||||||
UNDERSTANDING_UPDATED = "understanding_updated"
|
|
||||||
AGENT_PREVIEW = "agent_preview"
|
|
||||||
AGENT_SAVED = "agent_saved"
|
|
||||||
CLARIFICATION_NEEDED = "clarification_needed"
|
|
||||||
BLOCK_LIST = "block_list"
|
|
||||||
BLOCK_OUTPUT = "block_output"
|
|
||||||
DOC_SEARCH_RESULTS = "doc_search_results"
|
|
||||||
DOC_PAGE = "doc_page"
|
|
||||||
# Workspace response types
|
|
||||||
WORKSPACE_FILE_LIST = "workspace_file_list"
|
|
||||||
WORKSPACE_FILE_CONTENT = "workspace_file_content"
|
|
||||||
WORKSPACE_FILE_METADATA = "workspace_file_metadata"
|
|
||||||
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
|
|
||||||
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
|
|
||||||
# Long-running operation types
|
|
||||||
OPERATION_STARTED = "operation_started"
|
|
||||||
OPERATION_PENDING = "operation_pending"
|
|
||||||
OPERATION_IN_PROGRESS = "operation_in_progress"
|
|
||||||
# Input validation
|
|
||||||
INPUT_VALIDATION_ERROR = "input_validation_error"
|
|
||||||
|
|
||||||
|
|
||||||
# Base response model
|
|
||||||
class ToolResponseBase(BaseModel):
|
|
||||||
"""Base model for all tool responses."""
|
|
||||||
|
|
||||||
type: ResponseType
|
|
||||||
message: str
|
|
||||||
session_id: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
# Agent discovery models
|
|
||||||
class AgentInfo(BaseModel):
|
|
||||||
"""Information about an agent."""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
source: str = Field(description="marketplace or library")
|
|
||||||
in_library: bool = False
|
|
||||||
creator: str | None = None
|
|
||||||
category: str | None = None
|
|
||||||
rating: float | None = None
|
|
||||||
runs: int | None = None
|
|
||||||
is_featured: bool | None = None
|
|
||||||
status: str | None = None
|
|
||||||
can_access_graph: bool | None = None
|
|
||||||
has_external_trigger: bool | None = None
|
|
||||||
new_output: bool | None = None
|
|
||||||
graph_id: str | None = None
|
|
||||||
inputs: dict[str, Any] | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="Input schema for the agent, including field names, types, and defaults",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AgentsFoundResponse(ToolResponseBase):
|
|
||||||
"""Response for find_agent tool."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.AGENTS_FOUND
|
|
||||||
title: str = "Available Agents"
|
|
||||||
agents: list[AgentInfo]
|
|
||||||
count: int
|
|
||||||
name: str = "agents_found"
|
|
||||||
|
|
||||||
|
|
||||||
class NoResultsResponse(ToolResponseBase):
|
|
||||||
"""Response when no agents found."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.NO_RESULTS
|
|
||||||
suggestions: list[str] = []
|
|
||||||
name: str = "no_results"
|
|
||||||
|
|
||||||
|
|
||||||
# Agent details models
|
|
||||||
class InputField(BaseModel):
|
|
||||||
"""Input field specification."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
type: str = "string"
|
|
||||||
description: str = ""
|
|
||||||
required: bool = False
|
|
||||||
default: Any | None = None
|
|
||||||
options: list[Any] | None = None
|
|
||||||
format: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionOptions(BaseModel):
|
|
||||||
"""Available execution options for an agent."""
|
|
||||||
|
|
||||||
manual: bool = True
|
|
||||||
scheduled: bool = True
|
|
||||||
webhook: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class AgentDetails(BaseModel):
|
|
||||||
"""Detailed agent information."""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
in_library: bool = False
|
|
||||||
inputs: dict[str, Any] = {}
|
|
||||||
credentials: list[CredentialsMetaInput] = []
|
|
||||||
execution_options: ExecutionOptions = Field(default_factory=ExecutionOptions)
|
|
||||||
trigger_info: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class AgentDetailsResponse(ToolResponseBase):
|
|
||||||
"""Response for get_details action."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.AGENT_DETAILS
|
|
||||||
agent: AgentDetails
|
|
||||||
user_authenticated: bool = False
|
|
||||||
graph_id: str | None = None
|
|
||||||
graph_version: int | None = None
|
|
||||||
|
|
||||||
|
|
||||||
# Setup info models
|
|
||||||
class UserReadiness(BaseModel):
|
|
||||||
"""User readiness status."""
|
|
||||||
|
|
||||||
has_all_credentials: bool = False
|
|
||||||
missing_credentials: dict[str, Any] = {}
|
|
||||||
ready_to_run: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class SetupInfo(BaseModel):
|
|
||||||
"""Complete setup information."""
|
|
||||||
|
|
||||||
agent_id: str
|
|
||||||
agent_name: str
|
|
||||||
requirements: dict[str, list[Any]] = Field(
|
|
||||||
default_factory=lambda: {
|
|
||||||
"credentials": [],
|
|
||||||
"inputs": [],
|
|
||||||
"execution_modes": [],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
user_readiness: UserReadiness = Field(default_factory=UserReadiness)
|
|
||||||
|
|
||||||
|
|
||||||
class SetupRequirementsResponse(ToolResponseBase):
|
|
||||||
"""Response for validate action."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.SETUP_REQUIREMENTS
|
|
||||||
setup_info: SetupInfo
|
|
||||||
graph_id: str | None = None
|
|
||||||
graph_version: int | None = None
|
|
||||||
|
|
||||||
|
|
||||||
# Execution models
|
|
||||||
class ExecutionStartedResponse(ToolResponseBase):
|
|
||||||
"""Response for run/schedule actions."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.EXECUTION_STARTED
|
|
||||||
execution_id: str
|
|
||||||
graph_id: str
|
|
||||||
graph_name: str
|
|
||||||
library_agent_id: str | None = None
|
|
||||||
library_agent_link: str | None = None
|
|
||||||
status: str = "QUEUED"
|
|
||||||
|
|
||||||
|
|
||||||
# Auth/error models
|
|
||||||
class NeedLoginResponse(ToolResponseBase):
|
|
||||||
"""Response when login is needed."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.NEED_LOGIN
|
|
||||||
agent_info: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ErrorResponse(ToolResponseBase):
|
|
||||||
"""Response for errors."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.ERROR
|
|
||||||
error: str | None = None
|
|
||||||
details: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class InputValidationErrorResponse(ToolResponseBase):
|
|
||||||
"""Response when run_agent receives unknown input fields."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR
|
|
||||||
unrecognized_fields: list[str] = Field(
|
|
||||||
description="List of input field names that were not recognized"
|
|
||||||
)
|
|
||||||
inputs: dict[str, Any] = Field(
|
|
||||||
description="The agent's valid input schema for reference"
|
|
||||||
)
|
|
||||||
graph_id: str | None = None
|
|
||||||
graph_version: int | None = None
|
|
||||||
|
|
||||||
|
|
||||||
# Agent output models
|
|
||||||
class ExecutionOutputInfo(BaseModel):
|
|
||||||
"""Summary of a single execution's outputs."""
|
|
||||||
|
|
||||||
execution_id: str
|
|
||||||
status: str
|
|
||||||
started_at: datetime | None = None
|
|
||||||
ended_at: datetime | None = None
|
|
||||||
outputs: dict[str, list[Any]]
|
|
||||||
inputs_summary: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class AgentOutputResponse(ToolResponseBase):
|
|
||||||
"""Response for agent_output tool."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.AGENT_OUTPUT
|
|
||||||
agent_name: str
|
|
||||||
agent_id: str
|
|
||||||
library_agent_id: str | None = None
|
|
||||||
library_agent_link: str | None = None
|
|
||||||
execution: ExecutionOutputInfo | None = None
|
|
||||||
available_executions: list[dict[str, Any]] | None = None
|
|
||||||
total_executions: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
# Business understanding models
|
|
||||||
class UnderstandingUpdatedResponse(ToolResponseBase):
|
|
||||||
"""Response for add_understanding tool."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.UNDERSTANDING_UPDATED
|
|
||||||
updated_fields: list[str] = Field(default_factory=list)
|
|
||||||
current_understanding: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
# Agent generation models
|
|
||||||
class ClarifyingQuestion(BaseModel):
|
|
||||||
"""A question that needs user clarification."""
|
|
||||||
|
|
||||||
question: str
|
|
||||||
keyword: str
|
|
||||||
example: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class AgentPreviewResponse(ToolResponseBase):
|
|
||||||
"""Response for previewing a generated agent before saving."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.AGENT_PREVIEW
|
|
||||||
agent_json: dict[str, Any]
|
|
||||||
agent_name: str
|
|
||||||
description: str
|
|
||||||
node_count: int
|
|
||||||
link_count: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
class AgentSavedResponse(ToolResponseBase):
|
|
||||||
"""Response when an agent is saved to the library."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.AGENT_SAVED
|
|
||||||
agent_id: str
|
|
||||||
agent_name: str
|
|
||||||
library_agent_id: str
|
|
||||||
library_agent_link: str
|
|
||||||
agent_page_link: str # Link to the agent builder/editor page
|
|
||||||
|
|
||||||
|
|
||||||
class ClarificationNeededResponse(ToolResponseBase):
|
|
||||||
"""Response when the LLM needs more information from the user."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.CLARIFICATION_NEEDED
|
|
||||||
questions: list[ClarifyingQuestion] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
# Documentation search models
|
|
||||||
class DocSearchResult(BaseModel):
|
|
||||||
"""A single documentation search result."""
|
|
||||||
|
|
||||||
title: str
|
|
||||||
path: str
|
|
||||||
section: str
|
|
||||||
snippet: str # Short excerpt for UI display
|
|
||||||
score: float
|
|
||||||
doc_url: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class DocSearchResultsResponse(ToolResponseBase):
|
|
||||||
"""Response for search_docs tool."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.DOC_SEARCH_RESULTS
|
|
||||||
results: list[DocSearchResult]
|
|
||||||
count: int
|
|
||||||
query: str
|
|
||||||
|
|
||||||
|
|
||||||
class DocPageResponse(ToolResponseBase):
|
|
||||||
"""Response for get_doc_page tool."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.DOC_PAGE
|
|
||||||
title: str
|
|
||||||
path: str
|
|
||||||
content: str # Full document content
|
|
||||||
doc_url: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
# Block models
|
|
||||||
class BlockInputFieldInfo(BaseModel):
|
|
||||||
"""Information about a block input field."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
type: str
|
|
||||||
description: str = ""
|
|
||||||
required: bool = False
|
|
||||||
default: Any | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class BlockInfoSummary(BaseModel):
|
|
||||||
"""Summary of a block for search results."""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
categories: list[str]
|
|
||||||
input_schema: dict[str, Any]
|
|
||||||
output_schema: dict[str, Any]
|
|
||||||
required_inputs: list[BlockInputFieldInfo] = Field(
|
|
||||||
default_factory=list,
|
|
||||||
description="List of required input fields for this block",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BlockListResponse(ToolResponseBase):
|
|
||||||
"""Response for find_block tool."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.BLOCK_LIST
|
|
||||||
blocks: list[BlockInfoSummary]
|
|
||||||
count: int
|
|
||||||
query: str
|
|
||||||
usage_hint: str = Field(
|
|
||||||
default="To execute a block, call run_block with block_id set to the block's "
|
|
||||||
"'id' field and input_data containing the required fields from input_schema."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BlockOutputResponse(ToolResponseBase):
|
|
||||||
"""Response for run_block tool."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.BLOCK_OUTPUT
|
|
||||||
block_id: str
|
|
||||||
block_name: str
|
|
||||||
outputs: dict[str, list[Any]]
|
|
||||||
success: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
# Long-running operation models
|
|
||||||
class OperationStartedResponse(ToolResponseBase):
|
|
||||||
"""Response when a long-running operation has been started in the background.
|
|
||||||
|
|
||||||
This is returned immediately to the client while the operation continues
|
|
||||||
to execute. The user can close the tab and check back later.
|
|
||||||
|
|
||||||
The task_id can be used to reconnect to the SSE stream via
|
|
||||||
GET /chat/tasks/{task_id}/stream?last_idx=0
|
|
||||||
"""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
|
||||||
operation_id: str
|
|
||||||
tool_name: str
|
|
||||||
task_id: str | None = None # For SSE reconnection
|
|
||||||
|
|
||||||
|
|
||||||
class OperationPendingResponse(ToolResponseBase):
|
|
||||||
"""Response stored in chat history while a long-running operation is executing.
|
|
||||||
|
|
||||||
This is persisted to the database so users see a pending state when they
|
|
||||||
refresh before the operation completes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.OPERATION_PENDING
|
|
||||||
operation_id: str
|
|
||||||
tool_name: str
|
|
||||||
|
|
||||||
|
|
||||||
class OperationInProgressResponse(ToolResponseBase):
|
|
||||||
"""Response when an operation is already in progress.
|
|
||||||
|
|
||||||
Returned for idempotency when the same tool_call_id is requested again
|
|
||||||
while the background task is still running.
|
|
||||||
"""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
|
|
||||||
tool_call_id: str
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncProcessingResponse(ToolResponseBase):
|
|
||||||
"""Response when an operation has been delegated to async processing.
|
|
||||||
|
|
||||||
This is returned by tools when the external service accepts the request
|
|
||||||
for async processing (HTTP 202 Accepted). The Redis Streams completion
|
|
||||||
consumer will handle the result when the external service completes.
|
|
||||||
|
|
||||||
The status field is specifically "accepted" to allow the long-running tool
|
|
||||||
handler to detect this response and skip LLM continuation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
|
||||||
status: str = "accepted" # Must be "accepted" for detection
|
|
||||||
operation_id: str | None = None
|
|
||||||
task_id: str | None = None
|
|
||||||
@@ -1,356 +0,0 @@
|
|||||||
"""Tool for executing blocks directly."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
from collections import defaultdict
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic_core import PydanticUndefined
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
|
||||||
from backend.api.features.chat.tools.find_block import (
|
|
||||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
|
||||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
|
||||||
)
|
|
||||||
from backend.blocks import get_block
|
|
||||||
from backend.blocks._base import AnyBlockSchema
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
|
||||||
from backend.data.workspace import get_or_create_workspace
|
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
|
||||||
from backend.util.exceptions import BlockError
|
|
||||||
|
|
||||||
from .base import BaseTool
|
|
||||||
from .helpers import get_inputs_from_schema
|
|
||||||
from .models import (
|
|
||||||
BlockOutputResponse,
|
|
||||||
ErrorResponse,
|
|
||||||
SetupInfo,
|
|
||||||
SetupRequirementsResponse,
|
|
||||||
ToolResponseBase,
|
|
||||||
UserReadiness,
|
|
||||||
)
|
|
||||||
from .utils import (
|
|
||||||
build_missing_credentials_from_field_info,
|
|
||||||
match_credentials_to_requirements,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class RunBlockTool(BaseTool):
|
|
||||||
"""Tool for executing a block and returning its outputs."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "run_block"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return (
|
|
||||||
"Execute a specific block with the provided input data. "
|
|
||||||
"IMPORTANT: You MUST call find_block first to get the block's 'id' - "
|
|
||||||
"do NOT guess or make up block IDs. "
|
|
||||||
"Use the 'id' from find_block results and provide input_data "
|
|
||||||
"matching the block's required_inputs."
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"block_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"The block's 'id' field from find_block results. "
|
|
||||||
"NEVER guess this - always get it from find_block first."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"input_data": {
|
|
||||||
"type": "object",
|
|
||||||
"description": (
|
|
||||||
"Input values for the block. Use the 'required_inputs' field "
|
|
||||||
"from find_block to see what fields are needed."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["block_id", "input_data"],
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_auth(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _execute(
|
|
||||||
self,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
**kwargs,
|
|
||||||
) -> ToolResponseBase:
|
|
||||||
"""Execute a block with the given input data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User ID (required)
|
|
||||||
session: Chat session
|
|
||||||
block_id: Block UUID to execute
|
|
||||||
input_data: Input values for the block
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BlockOutputResponse: Block execution outputs
|
|
||||||
SetupRequirementsResponse: Missing credentials
|
|
||||||
ErrorResponse: Error message
|
|
||||||
"""
|
|
||||||
block_id = kwargs.get("block_id", "").strip()
|
|
||||||
input_data = kwargs.get("input_data", {})
|
|
||||||
session_id = session.session_id
|
|
||||||
|
|
||||||
if not block_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Please provide a block_id",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(input_data, dict):
|
|
||||||
return ErrorResponse(
|
|
||||||
message="input_data must be an object",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not user_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Authentication required",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the block
|
|
||||||
block = get_block(block_id)
|
|
||||||
if not block:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Block '{block_id}' not found",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
if block.disabled:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Block '{block_id}' is disabled",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if block is excluded from CoPilot (graph-only blocks)
|
|
||||||
if (
|
|
||||||
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
|
||||||
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
|
||||||
):
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
f"Block '{block.name}' cannot be run directly in CoPilot. "
|
|
||||||
"This block is designed for use within graphs only."
|
|
||||||
),
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
|
||||||
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
|
||||||
matched_credentials, missing_credentials = (
|
|
||||||
await self._resolve_block_credentials(user_id, block, input_data)
|
|
||||||
)
|
|
||||||
|
|
||||||
if missing_credentials:
|
|
||||||
# Return setup requirements response with missing credentials
|
|
||||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
|
||||||
missing_creds_dict = build_missing_credentials_from_field_info(
|
|
||||||
credentials_fields_info, set(matched_credentials.keys())
|
|
||||||
)
|
|
||||||
missing_creds_list = list(missing_creds_dict.values())
|
|
||||||
|
|
||||||
return SetupRequirementsResponse(
|
|
||||||
message=(
|
|
||||||
f"Block '{block.name}' requires credentials that are not configured. "
|
|
||||||
"Please set up the required credentials before running this block."
|
|
||||||
),
|
|
||||||
session_id=session_id,
|
|
||||||
setup_info=SetupInfo(
|
|
||||||
agent_id=block_id,
|
|
||||||
agent_name=block.name,
|
|
||||||
user_readiness=UserReadiness(
|
|
||||||
has_all_credentials=False,
|
|
||||||
missing_credentials=missing_creds_dict,
|
|
||||||
ready_to_run=False,
|
|
||||||
),
|
|
||||||
requirements={
|
|
||||||
"credentials": missing_creds_list,
|
|
||||||
"inputs": self._get_inputs_list(block),
|
|
||||||
"execution_modes": ["immediate"],
|
|
||||||
},
|
|
||||||
),
|
|
||||||
graph_id=None,
|
|
||||||
graph_version=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get or create user's workspace for CoPilot file operations
|
|
||||||
workspace = await get_or_create_workspace(user_id)
|
|
||||||
|
|
||||||
# Generate synthetic IDs for CoPilot context
|
|
||||||
# Each chat session is treated as its own agent with one continuous run
|
|
||||||
# This means:
|
|
||||||
# - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
|
|
||||||
# - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
|
|
||||||
# - node_exec_id = unique per block execution
|
|
||||||
synthetic_graph_id = f"copilot-session-{session.session_id}"
|
|
||||||
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
|
|
||||||
synthetic_node_id = f"copilot-node-{block_id}"
|
|
||||||
synthetic_node_exec_id = (
|
|
||||||
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create unified execution context with all required fields
|
|
||||||
execution_context = ExecutionContext(
|
|
||||||
# Execution identity
|
|
||||||
user_id=user_id,
|
|
||||||
graph_id=synthetic_graph_id,
|
|
||||||
graph_exec_id=synthetic_graph_exec_id,
|
|
||||||
graph_version=1, # Versions are 1-indexed
|
|
||||||
node_id=synthetic_node_id,
|
|
||||||
node_exec_id=synthetic_node_exec_id,
|
|
||||||
# Workspace with session scoping
|
|
||||||
workspace_id=workspace.id,
|
|
||||||
session_id=session.session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare kwargs for block execution
|
|
||||||
# Keep individual kwargs for backwards compatibility with existing blocks
|
|
||||||
exec_kwargs: dict[str, Any] = {
|
|
||||||
"user_id": user_id,
|
|
||||||
"execution_context": execution_context,
|
|
||||||
# Legacy: individual kwargs for blocks not yet using execution_context
|
|
||||||
"workspace_id": workspace.id,
|
|
||||||
"graph_exec_id": synthetic_graph_exec_id,
|
|
||||||
"node_exec_id": synthetic_node_exec_id,
|
|
||||||
"node_id": synthetic_node_id,
|
|
||||||
"graph_version": 1, # Versions are 1-indexed
|
|
||||||
"graph_id": synthetic_graph_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
for field_name, cred_meta in matched_credentials.items():
|
|
||||||
# Inject metadata into input_data (for validation)
|
|
||||||
if field_name not in input_data:
|
|
||||||
input_data[field_name] = cred_meta.model_dump()
|
|
||||||
|
|
||||||
# Fetch actual credentials and pass as kwargs (for execution)
|
|
||||||
actual_credentials = await creds_manager.get(
|
|
||||||
user_id, cred_meta.id, lock=False
|
|
||||||
)
|
|
||||||
if actual_credentials:
|
|
||||||
exec_kwargs[field_name] = actual_credentials
|
|
||||||
else:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Failed to retrieve credentials for {field_name}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Execute the block and collect outputs
|
|
||||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
|
||||||
async for output_name, output_data in block.execute(
|
|
||||||
input_data,
|
|
||||||
**exec_kwargs,
|
|
||||||
):
|
|
||||||
outputs[output_name].append(output_data)
|
|
||||||
|
|
||||||
return BlockOutputResponse(
|
|
||||||
message=f"Block '{block.name}' executed successfully",
|
|
||||||
block_id=block_id,
|
|
||||||
block_name=block.name,
|
|
||||||
outputs=dict(outputs),
|
|
||||||
success=True,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
except BlockError as e:
|
|
||||||
logger.warning(f"Block execution failed: {e}")
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Block execution failed: {e}",
|
|
||||||
error=str(e),
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Failed to execute block: {str(e)}",
|
|
||||||
error=str(e),
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _resolve_block_credentials(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
block: AnyBlockSchema,
|
|
||||||
input_data: dict[str, Any] | None = None,
|
|
||||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
|
||||||
"""
|
|
||||||
Resolve credentials for a block by matching user's available credentials.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User ID
|
|
||||||
block: Block to resolve credentials for
|
|
||||||
input_data: Input data for the block (used to determine provider via discriminator)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple of (matched_credentials, missing_credentials) - matched credentials
|
|
||||||
are used for block execution, missing ones indicate setup requirements.
|
|
||||||
"""
|
|
||||||
input_data = input_data or {}
|
|
||||||
requirements = self._resolve_discriminated_credentials(block, input_data)
|
|
||||||
|
|
||||||
if not requirements:
|
|
||||||
return {}, []
|
|
||||||
|
|
||||||
return await match_credentials_to_requirements(user_id, requirements)
|
|
||||||
|
|
||||||
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
|
|
||||||
"""Extract non-credential inputs from block schema."""
|
|
||||||
schema = block.input_schema.jsonschema()
|
|
||||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
|
||||||
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
|
|
||||||
|
|
||||||
def _resolve_discriminated_credentials(
|
|
||||||
self,
|
|
||||||
block: AnyBlockSchema,
|
|
||||||
input_data: dict[str, Any],
|
|
||||||
) -> dict[str, CredentialsFieldInfo]:
|
|
||||||
"""Resolve credential requirements, applying discriminator logic where needed."""
|
|
||||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
|
||||||
if not credentials_fields_info:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
resolved: dict[str, CredentialsFieldInfo] = {}
|
|
||||||
|
|
||||||
for field_name, field_info in credentials_fields_info.items():
|
|
||||||
effective_field_info = field_info
|
|
||||||
|
|
||||||
if field_info.discriminator and field_info.discriminator_mapping:
|
|
||||||
discriminator_value = input_data.get(field_info.discriminator)
|
|
||||||
if discriminator_value is None:
|
|
||||||
field = block.input_schema.model_fields.get(
|
|
||||||
field_info.discriminator
|
|
||||||
)
|
|
||||||
if field and field.default is not PydanticUndefined:
|
|
||||||
discriminator_value = field.default
|
|
||||||
|
|
||||||
if (
|
|
||||||
discriminator_value
|
|
||||||
and discriminator_value in field_info.discriminator_mapping
|
|
||||||
):
|
|
||||||
effective_field_info = field_info.discriminate(discriminator_value)
|
|
||||||
# For host-scoped credentials, add the discriminator value
|
|
||||||
# (e.g., URL) so _credential_is_for_host can match it
|
|
||||||
effective_field_info.discriminator_values.add(discriminator_value)
|
|
||||||
logger.debug(
|
|
||||||
f"Discriminated provider for {field_name}: "
|
|
||||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
|
||||||
)
|
|
||||||
|
|
||||||
resolved[field_name] = effective_field_info
|
|
||||||
|
|
||||||
return resolved
|
|
||||||
@@ -1,106 +0,0 @@
|
|||||||
"""Tests for block execution guards in RunBlockTool."""
|
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from backend.api.features.chat.tools.models import ErrorResponse
|
|
||||||
from backend.api.features.chat.tools.run_block import RunBlockTool
|
|
||||||
from backend.blocks._base import BlockType
|
|
||||||
|
|
||||||
from ._test_data import make_session
|
|
||||||
|
|
||||||
_TEST_USER_ID = "test-user-run-block"
|
|
||||||
|
|
||||||
|
|
||||||
def make_mock_block(
|
|
||||||
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
|
||||||
):
|
|
||||||
"""Create a mock block for testing."""
|
|
||||||
mock = MagicMock()
|
|
||||||
mock.id = block_id
|
|
||||||
mock.name = name
|
|
||||||
mock.block_type = block_type
|
|
||||||
mock.disabled = disabled
|
|
||||||
mock.input_schema = MagicMock()
|
|
||||||
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
|
||||||
mock.input_schema.get_credentials_fields_info.return_value = []
|
|
||||||
return mock
|
|
||||||
|
|
||||||
|
|
||||||
class TestRunBlockFiltering:
|
|
||||||
"""Tests for block execution guards in RunBlockTool."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_excluded_block_type_returns_error(self):
|
|
||||||
"""Attempting to execute a block with excluded BlockType returns error."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.run_block.get_block",
|
|
||||||
return_value=input_block,
|
|
||||||
):
|
|
||||||
tool = RunBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID,
|
|
||||||
session=session,
|
|
||||||
block_id="input-block-id",
|
|
||||||
input_data={},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(response, ErrorResponse)
|
|
||||||
assert "cannot be run directly in CoPilot" in response.message
|
|
||||||
assert "designed for use within graphs only" in response.message
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_excluded_block_id_returns_error(self):
|
|
||||||
"""Attempting to execute SmartDecisionMakerBlock returns error."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
|
||||||
smart_block = make_mock_block(
|
|
||||||
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.run_block.get_block",
|
|
||||||
return_value=smart_block,
|
|
||||||
):
|
|
||||||
tool = RunBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID,
|
|
||||||
session=session,
|
|
||||||
block_id=smart_decision_id,
|
|
||||||
input_data={},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(response, ErrorResponse)
|
|
||||||
assert "cannot be run directly in CoPilot" in response.message
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_non_excluded_block_passes_guard(self):
|
|
||||||
"""Non-excluded blocks pass the filtering guard (may fail later for other reasons)."""
|
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
|
||||||
|
|
||||||
standard_block = make_mock_block(
|
|
||||||
"standard-id", "HTTP Request", BlockType.STANDARD
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.run_block.get_block",
|
|
||||||
return_value=standard_block,
|
|
||||||
):
|
|
||||||
tool = RunBlockTool()
|
|
||||||
response = await tool._execute(
|
|
||||||
user_id=_TEST_USER_ID,
|
|
||||||
session=session,
|
|
||||||
block_id="standard-id",
|
|
||||||
input_data={},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should NOT be an ErrorResponse about CoPilot exclusion
|
|
||||||
# (may be other errors like missing credentials, but not the exclusion guard)
|
|
||||||
if isinstance(response, ErrorResponse):
|
|
||||||
assert "cannot be run directly in CoPilot" not in response.message
|
|
||||||
@@ -1,620 +0,0 @@
|
|||||||
"""CoPilot tools for workspace file operations."""
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import logging
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
|
||||||
from backend.data.workspace import get_or_create_workspace
|
|
||||||
from backend.util.settings import Config
|
|
||||||
from backend.util.virus_scanner import scan_content_safe
|
|
||||||
from backend.util.workspace import WorkspaceManager
|
|
||||||
|
|
||||||
from .base import BaseTool
|
|
||||||
from .models import ErrorResponse, ResponseType, ToolResponseBase
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceFileInfoData(BaseModel):
|
|
||||||
"""Data model for workspace file information (not a response itself)."""
|
|
||||||
|
|
||||||
file_id: str
|
|
||||||
name: str
|
|
||||||
path: str
|
|
||||||
mime_type: str
|
|
||||||
size_bytes: int
|
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceFileListResponse(ToolResponseBase):
|
|
||||||
"""Response containing list of workspace files."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
|
|
||||||
files: list[WorkspaceFileInfoData]
|
|
||||||
total_count: int
|
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceFileContentResponse(ToolResponseBase):
|
|
||||||
"""Response containing workspace file content (legacy, for small text files)."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
|
|
||||||
file_id: str
|
|
||||||
name: str
|
|
||||||
path: str
|
|
||||||
mime_type: str
|
|
||||||
content_base64: str
|
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceFileMetadataResponse(ToolResponseBase):
|
|
||||||
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
|
|
||||||
file_id: str
|
|
||||||
name: str
|
|
||||||
path: str
|
|
||||||
mime_type: str
|
|
||||||
size_bytes: int
|
|
||||||
download_url: str
|
|
||||||
preview: str | None = None # First 500 chars for text files
|
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceWriteResponse(ToolResponseBase):
|
|
||||||
"""Response after writing a file to workspace."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
|
|
||||||
file_id: str
|
|
||||||
name: str
|
|
||||||
path: str
|
|
||||||
size_bytes: int
|
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceDeleteResponse(ToolResponseBase):
|
|
||||||
"""Response after deleting a file from workspace."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
|
|
||||||
file_id: str
|
|
||||||
success: bool
|
|
||||||
|
|
||||||
|
|
||||||
class ListWorkspaceFilesTool(BaseTool):
|
|
||||||
"""Tool for listing files in user's workspace."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "list_workspace_files"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return (
|
|
||||||
"List files in the user's workspace. "
|
|
||||||
"Returns file names, paths, sizes, and metadata. "
|
|
||||||
"Optionally filter by path prefix."
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"path_prefix": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"Optional path prefix to filter files "
|
|
||||||
"(e.g., '/documents/' to list only files in documents folder). "
|
|
||||||
"By default, only files from the current session are listed."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"limit": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "Maximum number of files to return (default 50, max 100)",
|
|
||||||
"minimum": 1,
|
|
||||||
"maximum": 100,
|
|
||||||
},
|
|
||||||
"include_all_sessions": {
|
|
||||||
"type": "boolean",
|
|
||||||
"description": (
|
|
||||||
"If true, list files from all sessions. "
|
|
||||||
"Default is false (only current session's files)."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_auth(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _execute(
|
|
||||||
self,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
**kwargs,
|
|
||||||
) -> ToolResponseBase:
|
|
||||||
session_id = session.session_id
|
|
||||||
|
|
||||||
if not user_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Authentication required",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
path_prefix: Optional[str] = kwargs.get("path_prefix")
|
|
||||||
limit = min(kwargs.get("limit", 50), 100)
|
|
||||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
|
||||||
|
|
||||||
try:
|
|
||||||
workspace = await get_or_create_workspace(user_id)
|
|
||||||
# Pass session_id for session-scoped file access
|
|
||||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
|
||||||
|
|
||||||
files = await manager.list_files(
|
|
||||||
path=path_prefix,
|
|
||||||
limit=limit,
|
|
||||||
include_all_sessions=include_all_sessions,
|
|
||||||
)
|
|
||||||
total = await manager.get_file_count(
|
|
||||||
path=path_prefix,
|
|
||||||
include_all_sessions=include_all_sessions,
|
|
||||||
)
|
|
||||||
|
|
||||||
file_infos = [
|
|
||||||
WorkspaceFileInfoData(
|
|
||||||
file_id=f.id,
|
|
||||||
name=f.name,
|
|
||||||
path=f.path,
|
|
||||||
mime_type=f.mimeType,
|
|
||||||
size_bytes=f.sizeBytes,
|
|
||||||
)
|
|
||||||
for f in files
|
|
||||||
]
|
|
||||||
|
|
||||||
scope_msg = "all sessions" if include_all_sessions else "current session"
|
|
||||||
return WorkspaceFileListResponse(
|
|
||||||
files=file_infos,
|
|
||||||
total_count=total,
|
|
||||||
message=f"Found {len(files)} files in workspace ({scope_msg})",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error listing workspace files: {e}", exc_info=True)
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Failed to list workspace files: {str(e)}",
|
|
||||||
error=str(e),
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ReadWorkspaceFileTool(BaseTool):
|
|
||||||
"""Tool for reading file content from workspace."""
|
|
||||||
|
|
||||||
# Size threshold for returning full content vs metadata+URL
|
|
||||||
# Files larger than this return metadata with download URL to prevent context bloat
|
|
||||||
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
|
|
||||||
# Preview size for text files
|
|
||||||
PREVIEW_SIZE = 500
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "read_workspace_file"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return (
|
|
||||||
"Read a file from the user's workspace. "
|
|
||||||
"Specify either file_id or path to identify the file. "
|
|
||||||
"For small text files, returns content directly. "
|
|
||||||
"For large or binary files, returns metadata and a download URL. "
|
|
||||||
"Paths are scoped to the current session by default. "
|
|
||||||
"Use /sessions/<session_id>/... for cross-session access."
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"file_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The file's unique ID (from list_workspace_files)",
|
|
||||||
},
|
|
||||||
"path": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"The virtual file path (e.g., '/documents/report.pdf'). "
|
|
||||||
"Scoped to current session by default."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"force_download_url": {
|
|
||||||
"type": "boolean",
|
|
||||||
"description": (
|
|
||||||
"If true, always return metadata+URL instead of inline content. "
|
|
||||||
"Default is false (auto-selects based on file size/type)."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": [], # At least one must be provided
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_auth(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _is_text_mime_type(self, mime_type: str) -> bool:
|
|
||||||
"""Check if the MIME type is a text-based type."""
|
|
||||||
text_types = [
|
|
||||||
"text/",
|
|
||||||
"application/json",
|
|
||||||
"application/xml",
|
|
||||||
"application/javascript",
|
|
||||||
"application/x-python",
|
|
||||||
"application/x-sh",
|
|
||||||
]
|
|
||||||
return any(mime_type.startswith(t) for t in text_types)
|
|
||||||
|
|
||||||
async def _execute(
|
|
||||||
self,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
**kwargs,
|
|
||||||
) -> ToolResponseBase:
|
|
||||||
session_id = session.session_id
|
|
||||||
|
|
||||||
if not user_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Authentication required",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
file_id: Optional[str] = kwargs.get("file_id")
|
|
||||||
path: Optional[str] = kwargs.get("path")
|
|
||||||
force_download_url: bool = kwargs.get("force_download_url", False)
|
|
||||||
|
|
||||||
if not file_id and not path:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Please provide either file_id or path",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
workspace = await get_or_create_workspace(user_id)
|
|
||||||
# Pass session_id for session-scoped file access
|
|
||||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
|
||||||
|
|
||||||
# Get file info
|
|
||||||
if file_id:
|
|
||||||
file_info = await manager.get_file_info(file_id)
|
|
||||||
if file_info is None:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"File not found: {file_id}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
target_file_id = file_id
|
|
||||||
else:
|
|
||||||
# path is guaranteed to be non-None here due to the check above
|
|
||||||
assert path is not None
|
|
||||||
file_info = await manager.get_file_info_by_path(path)
|
|
||||||
if file_info is None:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"File not found at path: {path}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
target_file_id = file_info.id
|
|
||||||
|
|
||||||
# Decide whether to return inline content or metadata+URL
|
|
||||||
is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
|
|
||||||
is_text_file = self._is_text_mime_type(file_info.mimeType)
|
|
||||||
|
|
||||||
# Return inline content for small text files (unless force_download_url)
|
|
||||||
if is_small_file and is_text_file and not force_download_url:
|
|
||||||
content = await manager.read_file_by_id(target_file_id)
|
|
||||||
content_b64 = base64.b64encode(content).decode("utf-8")
|
|
||||||
|
|
||||||
return WorkspaceFileContentResponse(
|
|
||||||
file_id=file_info.id,
|
|
||||||
name=file_info.name,
|
|
||||||
path=file_info.path,
|
|
||||||
mime_type=file_info.mimeType,
|
|
||||||
content_base64=content_b64,
|
|
||||||
message=f"Successfully read file: {file_info.name}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return metadata + workspace:// reference for large or binary files
|
|
||||||
# This prevents context bloat (100KB file = ~133KB as base64)
|
|
||||||
# Use workspace:// format so frontend urlTransform can add proxy prefix
|
|
||||||
download_url = f"workspace://{target_file_id}"
|
|
||||||
|
|
||||||
# Generate preview for text files
|
|
||||||
preview: str | None = None
|
|
||||||
if is_text_file:
|
|
||||||
try:
|
|
||||||
content = await manager.read_file_by_id(target_file_id)
|
|
||||||
preview_text = content[: self.PREVIEW_SIZE].decode(
|
|
||||||
"utf-8", errors="replace"
|
|
||||||
)
|
|
||||||
if len(content) > self.PREVIEW_SIZE:
|
|
||||||
preview_text += "..."
|
|
||||||
preview = preview_text
|
|
||||||
except Exception:
|
|
||||||
pass # Preview is optional
|
|
||||||
|
|
||||||
return WorkspaceFileMetadataResponse(
|
|
||||||
file_id=file_info.id,
|
|
||||||
name=file_info.name,
|
|
||||||
path=file_info.path,
|
|
||||||
mime_type=file_info.mimeType,
|
|
||||||
size_bytes=file_info.sizeBytes,
|
|
||||||
download_url=download_url,
|
|
||||||
preview=preview,
|
|
||||||
message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=str(e),
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error reading workspace file: {e}", exc_info=True)
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Failed to read workspace file: {str(e)}",
|
|
||||||
error=str(e),
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class WriteWorkspaceFileTool(BaseTool):
|
|
||||||
"""Tool for writing files to workspace."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "write_workspace_file"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return (
|
|
||||||
"Write or create a file in the user's workspace. "
|
|
||||||
"Provide the content as a base64-encoded string. "
|
|
||||||
f"Maximum file size is {Config().max_file_size_mb}MB. "
|
|
||||||
"Files are saved to the current session's folder by default. "
|
|
||||||
"Use /sessions/<session_id>/... for cross-session access."
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"filename": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Name for the file (e.g., 'report.pdf')",
|
|
||||||
},
|
|
||||||
"content_base64": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Base64-encoded file content",
|
|
||||||
},
|
|
||||||
"path": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"Optional virtual path where to save the file "
|
|
||||||
"(e.g., '/documents/report.pdf'). "
|
|
||||||
"Defaults to '/{filename}'. Scoped to current session."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"mime_type": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"Optional MIME type of the file. "
|
|
||||||
"Auto-detected from filename if not provided."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"overwrite": {
|
|
||||||
"type": "boolean",
|
|
||||||
"description": "Whether to overwrite if file exists at path (default: false)",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["filename", "content_base64"],
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_auth(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _execute(
|
|
||||||
self,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
**kwargs,
|
|
||||||
) -> ToolResponseBase:
|
|
||||||
session_id = session.session_id
|
|
||||||
|
|
||||||
if not user_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Authentication required",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
filename: str = kwargs.get("filename", "")
|
|
||||||
content_b64: str = kwargs.get("content_base64", "")
|
|
||||||
path: Optional[str] = kwargs.get("path")
|
|
||||||
mime_type: Optional[str] = kwargs.get("mime_type")
|
|
||||||
overwrite: bool = kwargs.get("overwrite", False)
|
|
||||||
|
|
||||||
if not filename:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Please provide a filename",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not content_b64:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Please provide content_base64",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Decode content
|
|
||||||
try:
|
|
||||||
content = base64.b64decode(content_b64)
|
|
||||||
except Exception:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Invalid base64-encoded content",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check size
|
|
||||||
max_file_size = Config().max_file_size_mb * 1024 * 1024
|
|
||||||
if len(content) > max_file_size:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Virus scan
|
|
||||||
await scan_content_safe(content, filename=filename)
|
|
||||||
|
|
||||||
workspace = await get_or_create_workspace(user_id)
|
|
||||||
# Pass session_id for session-scoped file access
|
|
||||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
|
||||||
|
|
||||||
file_record = await manager.write_file(
|
|
||||||
content=content,
|
|
||||||
filename=filename,
|
|
||||||
path=path,
|
|
||||||
mime_type=mime_type,
|
|
||||||
overwrite=overwrite,
|
|
||||||
)
|
|
||||||
|
|
||||||
return WorkspaceWriteResponse(
|
|
||||||
file_id=file_record.id,
|
|
||||||
name=file_record.name,
|
|
||||||
path=file_record.path,
|
|
||||||
size_bytes=file_record.sizeBytes,
|
|
||||||
message=f"Successfully wrote file: {file_record.name}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=str(e),
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error writing workspace file: {e}", exc_info=True)
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Failed to write workspace file: {str(e)}",
|
|
||||||
error=str(e),
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DeleteWorkspaceFileTool(BaseTool):
|
|
||||||
"""Tool for deleting files from workspace."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "delete_workspace_file"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return (
|
|
||||||
"Delete a file from the user's workspace. "
|
|
||||||
"Specify either file_id or path to identify the file. "
|
|
||||||
"Paths are scoped to the current session by default. "
|
|
||||||
"Use /sessions/<session_id>/... for cross-session access."
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"file_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The file's unique ID (from list_workspace_files)",
|
|
||||||
},
|
|
||||||
"path": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"The virtual file path (e.g., '/documents/report.pdf'). "
|
|
||||||
"Scoped to current session by default."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": [], # At least one must be provided
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_auth(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _execute(
|
|
||||||
self,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
**kwargs,
|
|
||||||
) -> ToolResponseBase:
|
|
||||||
session_id = session.session_id
|
|
||||||
|
|
||||||
if not user_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Authentication required",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
file_id: Optional[str] = kwargs.get("file_id")
|
|
||||||
path: Optional[str] = kwargs.get("path")
|
|
||||||
|
|
||||||
if not file_id and not path:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Please provide either file_id or path",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
workspace = await get_or_create_workspace(user_id)
|
|
||||||
# Pass session_id for session-scoped file access
|
|
||||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
|
||||||
|
|
||||||
# Determine the file_id to delete
|
|
||||||
target_file_id: str
|
|
||||||
if file_id:
|
|
||||||
target_file_id = file_id
|
|
||||||
else:
|
|
||||||
# path is guaranteed to be non-None here due to the check above
|
|
||||||
assert path is not None
|
|
||||||
file_info = await manager.get_file_info_by_path(path)
|
|
||||||
if file_info is None:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"File not found at path: {path}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
target_file_id = file_info.id
|
|
||||||
|
|
||||||
success = await manager.delete_file(target_file_id)
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"File not found: {target_file_id}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
return WorkspaceDeleteResponse(
|
|
||||||
file_id=target_file_id,
|
|
||||||
success=True,
|
|
||||||
message="File deleted successfully",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error deleting workspace file: {e}", exc_info=True)
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Failed to delete workspace file: {str(e)}",
|
|
||||||
error=str(e),
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
@@ -638,7 +638,7 @@ async def test_process_review_action_auto_approve_creates_auto_approval_records(
|
|||||||
|
|
||||||
# Mock get_node_executions to return node_id mapping
|
# Mock get_node_executions to return node_id mapping
|
||||||
mock_get_node_executions = mocker.patch(
|
mock_get_node_executions = mocker.patch(
|
||||||
"backend.data.execution.get_node_executions"
|
"backend.api.features.executions.review.routes.get_node_executions"
|
||||||
)
|
)
|
||||||
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
|
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
|
||||||
mock_node_exec.node_exec_id = "test_node_123"
|
mock_node_exec.node_exec_id = "test_node_123"
|
||||||
@@ -936,7 +936,7 @@ async def test_process_review_action_auto_approve_only_applies_to_approved_revie
|
|||||||
|
|
||||||
# Mock get_node_executions to return node_id mapping
|
# Mock get_node_executions to return node_id mapping
|
||||||
mock_get_node_executions = mocker.patch(
|
mock_get_node_executions = mocker.patch(
|
||||||
"backend.data.execution.get_node_executions"
|
"backend.api.features.executions.review.routes.get_node_executions"
|
||||||
)
|
)
|
||||||
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
|
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
|
||||||
mock_node_exec.node_exec_id = "node_exec_approved"
|
mock_node_exec.node_exec_id = "node_exec_approved"
|
||||||
@@ -1148,7 +1148,7 @@ async def test_process_review_action_per_review_auto_approve_granularity(
|
|||||||
|
|
||||||
# Mock get_node_executions to return batch node data
|
# Mock get_node_executions to return batch node data
|
||||||
mock_get_node_executions = mocker.patch(
|
mock_get_node_executions = mocker.patch(
|
||||||
"backend.data.execution.get_node_executions"
|
"backend.api.features.executions.review.routes.get_node_executions"
|
||||||
)
|
)
|
||||||
# Create mock node executions for each review
|
# Create mock node executions for each review
|
||||||
mock_node_execs = []
|
mock_node_execs = []
|
||||||
|
|||||||
@@ -6,10 +6,15 @@ import autogpt_libs.auth as autogpt_auth_lib
|
|||||||
from fastapi import APIRouter, HTTPException, Query, Security, status
|
from fastapi import APIRouter, HTTPException, Query, Security, status
|
||||||
from prisma.enums import ReviewStatus
|
from prisma.enums import ReviewStatus
|
||||||
|
|
||||||
|
from backend.copilot.constants import (
|
||||||
|
is_copilot_synthetic_id,
|
||||||
|
parse_node_id_from_exec_id,
|
||||||
|
)
|
||||||
from backend.data.execution import (
|
from backend.data.execution import (
|
||||||
ExecutionContext,
|
ExecutionContext,
|
||||||
ExecutionStatus,
|
ExecutionStatus,
|
||||||
get_graph_execution_meta,
|
get_graph_execution_meta,
|
||||||
|
get_node_executions,
|
||||||
)
|
)
|
||||||
from backend.data.graph import get_graph_settings
|
from backend.data.graph import get_graph_settings
|
||||||
from backend.data.human_review import (
|
from backend.data.human_review import (
|
||||||
@@ -22,6 +27,7 @@ from backend.data.human_review import (
|
|||||||
)
|
)
|
||||||
from backend.data.model import USER_TIMEZONE_NOT_SET
|
from backend.data.model import USER_TIMEZONE_NOT_SET
|
||||||
from backend.data.user import get_user_by_id
|
from backend.data.user import get_user_by_id
|
||||||
|
from backend.data.workspace import get_or_create_workspace
|
||||||
from backend.executor.utils import add_graph_execution
|
from backend.executor.utils import add_graph_execution
|
||||||
|
|
||||||
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
|
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
|
||||||
@@ -35,6 +41,38 @@ router = APIRouter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_node_ids(
|
||||||
|
node_exec_ids: list[str],
|
||||||
|
graph_exec_id: str,
|
||||||
|
is_copilot: bool,
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Resolve node_exec_id -> node_id for auto-approval records.
|
||||||
|
|
||||||
|
CoPilot synthetic IDs encode node_id in the format "{node_id}:{random}".
|
||||||
|
Graph executions look up node_id from NodeExecution records.
|
||||||
|
"""
|
||||||
|
if not node_exec_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if is_copilot:
|
||||||
|
return {neid: parse_node_id_from_exec_id(neid) for neid in node_exec_ids}
|
||||||
|
|
||||||
|
node_execs = await get_node_executions(
|
||||||
|
graph_exec_id=graph_exec_id, include_exec_data=False
|
||||||
|
)
|
||||||
|
node_exec_map = {ne.node_exec_id: ne.node_id for ne in node_execs}
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
for neid in node_exec_ids:
|
||||||
|
if neid in node_exec_map:
|
||||||
|
result[neid] = node_exec_map[neid]
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to resolve node_id for {neid}: Node execution not found."
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/pending",
|
"/pending",
|
||||||
summary="Get Pending Reviews",
|
summary="Get Pending Reviews",
|
||||||
@@ -109,14 +147,16 @@ async def list_pending_reviews_for_execution(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Verify user owns the graph execution before returning reviews
|
# Verify user owns the graph execution before returning reviews
|
||||||
graph_exec = await get_graph_execution_meta(
|
# (CoPilot synthetic IDs don't have graph execution records)
|
||||||
user_id=user_id, execution_id=graph_exec_id
|
if not is_copilot_synthetic_id(graph_exec_id):
|
||||||
)
|
graph_exec = await get_graph_execution_meta(
|
||||||
if not graph_exec:
|
user_id=user_id, execution_id=graph_exec_id
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"Graph execution #{graph_exec_id} not found",
|
|
||||||
)
|
)
|
||||||
|
if not graph_exec:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Graph execution #{graph_exec_id} not found",
|
||||||
|
)
|
||||||
|
|
||||||
return await get_pending_reviews_for_execution(graph_exec_id, user_id)
|
return await get_pending_reviews_for_execution(graph_exec_id, user_id)
|
||||||
|
|
||||||
@@ -159,30 +199,26 @@ async def process_review_action(
|
|||||||
)
|
)
|
||||||
|
|
||||||
graph_exec_id = next(iter(graph_exec_ids))
|
graph_exec_id = next(iter(graph_exec_ids))
|
||||||
|
is_copilot = is_copilot_synthetic_id(graph_exec_id)
|
||||||
|
|
||||||
# Validate execution status before processing reviews
|
# Validate execution status for graph executions (skip for CoPilot synthetic IDs)
|
||||||
graph_exec_meta = await get_graph_execution_meta(
|
if not is_copilot:
|
||||||
user_id=user_id, execution_id=graph_exec_id
|
graph_exec_meta = await get_graph_execution_meta(
|
||||||
)
|
user_id=user_id, execution_id=graph_exec_id
|
||||||
|
|
||||||
if not graph_exec_meta:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"Graph execution #{graph_exec_id} not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only allow processing reviews if execution is paused for review
|
|
||||||
# or incomplete (partial execution with some reviews already processed)
|
|
||||||
if graph_exec_meta.status not in (
|
|
||||||
ExecutionStatus.REVIEW,
|
|
||||||
ExecutionStatus.INCOMPLETE,
|
|
||||||
):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
|
||||||
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
|
|
||||||
f"Reviews can only be processed when execution is paused (REVIEW status). "
|
|
||||||
f"Current status: {graph_exec_meta.status}",
|
|
||||||
)
|
)
|
||||||
|
if not graph_exec_meta:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Graph execution #{graph_exec_id} not found",
|
||||||
|
)
|
||||||
|
if graph_exec_meta.status not in (
|
||||||
|
ExecutionStatus.REVIEW,
|
||||||
|
ExecutionStatus.INCOMPLETE,
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}",
|
||||||
|
)
|
||||||
|
|
||||||
# Build review decisions map and track which reviews requested auto-approval
|
# Build review decisions map and track which reviews requested auto-approval
|
||||||
# Auto-approved reviews use original data (no modifications allowed)
|
# Auto-approved reviews use original data (no modifications allowed)
|
||||||
@@ -235,7 +271,7 @@ async def process_review_action(
|
|||||||
)
|
)
|
||||||
return (node_id, False)
|
return (node_id, False)
|
||||||
|
|
||||||
# Collect node_exec_ids that need auto-approval
|
# Collect node_exec_ids that need auto-approval and resolve their node_ids
|
||||||
node_exec_ids_needing_auto_approval = [
|
node_exec_ids_needing_auto_approval = [
|
||||||
node_exec_id
|
node_exec_id
|
||||||
for node_exec_id, review_result in updated_reviews.items()
|
for node_exec_id, review_result in updated_reviews.items()
|
||||||
@@ -243,29 +279,16 @@ async def process_review_action(
|
|||||||
and auto_approve_requests.get(node_exec_id, False)
|
and auto_approve_requests.get(node_exec_id, False)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Batch-fetch node executions to get node_ids
|
node_id_map = await _resolve_node_ids(
|
||||||
|
node_exec_ids_needing_auto_approval, graph_exec_id, is_copilot
|
||||||
|
)
|
||||||
|
|
||||||
|
# Deduplicate by node_id — one auto-approval per node
|
||||||
nodes_needing_auto_approval: dict[str, Any] = {}
|
nodes_needing_auto_approval: dict[str, Any] = {}
|
||||||
if node_exec_ids_needing_auto_approval:
|
for node_exec_id in node_exec_ids_needing_auto_approval:
|
||||||
from backend.data.execution import get_node_executions
|
node_id = node_id_map.get(node_exec_id)
|
||||||
|
if node_id and node_id not in nodes_needing_auto_approval:
|
||||||
node_execs = await get_node_executions(
|
nodes_needing_auto_approval[node_id] = updated_reviews[node_exec_id]
|
||||||
graph_exec_id=graph_exec_id, include_exec_data=False
|
|
||||||
)
|
|
||||||
node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
|
|
||||||
|
|
||||||
for node_exec_id in node_exec_ids_needing_auto_approval:
|
|
||||||
node_exec = node_exec_map.get(node_exec_id)
|
|
||||||
if node_exec:
|
|
||||||
review_result = updated_reviews[node_exec_id]
|
|
||||||
# Use the first approved review for this node (deduplicate by node_id)
|
|
||||||
if node_exec.node_id not in nodes_needing_auto_approval:
|
|
||||||
nodes_needing_auto_approval[node_exec.node_id] = review_result
|
|
||||||
else:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to create auto-approval record for {node_exec_id}: "
|
|
||||||
f"Node execution not found. This may indicate a race condition "
|
|
||||||
f"or data inconsistency."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Execute all auto-approval creations in parallel (deduplicated by node_id)
|
# Execute all auto-approval creations in parallel (deduplicated by node_id)
|
||||||
auto_approval_results = await asyncio.gather(
|
auto_approval_results = await asyncio.gather(
|
||||||
@@ -280,13 +303,11 @@ async def process_review_action(
|
|||||||
auto_approval_failed_count = 0
|
auto_approval_failed_count = 0
|
||||||
for result in auto_approval_results:
|
for result in auto_approval_results:
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
# Unexpected exception during auto-approval creation
|
|
||||||
auto_approval_failed_count += 1
|
auto_approval_failed_count += 1
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Unexpected exception during auto-approval creation: {result}"
|
f"Unexpected exception during auto-approval creation: {result}"
|
||||||
)
|
)
|
||||||
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
|
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
|
||||||
# Auto-approval creation failed (returned False)
|
|
||||||
auto_approval_failed_count += 1
|
auto_approval_failed_count += 1
|
||||||
|
|
||||||
# Count results
|
# Count results
|
||||||
@@ -301,30 +322,31 @@ async def process_review_action(
|
|||||||
if review.status == ReviewStatus.REJECTED
|
if review.status == ReviewStatus.REJECTED
|
||||||
)
|
)
|
||||||
|
|
||||||
# Resume execution only if ALL pending reviews for this execution have been processed
|
# Resume graph execution only for real graph executions (not CoPilot)
|
||||||
if updated_reviews:
|
# CoPilot sessions are resumed by the LLM retrying run_block with review_id
|
||||||
|
if not is_copilot and updated_reviews:
|
||||||
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
|
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
|
||||||
|
|
||||||
if not still_has_pending:
|
if not still_has_pending:
|
||||||
# Get the graph_id from any processed review
|
|
||||||
first_review = next(iter(updated_reviews.values()))
|
first_review = next(iter(updated_reviews.values()))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Fetch user and settings to build complete execution context
|
|
||||||
user = await get_user_by_id(user_id)
|
user = await get_user_by_id(user_id)
|
||||||
settings = await get_graph_settings(
|
settings = await get_graph_settings(
|
||||||
user_id=user_id, graph_id=first_review.graph_id
|
user_id=user_id, graph_id=first_review.graph_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Preserve user's timezone preference when resuming execution
|
|
||||||
user_timezone = (
|
user_timezone = (
|
||||||
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
workspace = await get_or_create_workspace(user_id)
|
||||||
|
|
||||||
execution_context = ExecutionContext(
|
execution_context = ExecutionContext(
|
||||||
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
|
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
|
||||||
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
|
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
|
||||||
user_timezone=user_timezone,
|
user_timezone=user_timezone,
|
||||||
|
workspace_id=workspace.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
await add_graph_execution(
|
await add_graph_execution(
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import TYPE_CHECKING, Annotated, List, Literal
|
from typing import TYPE_CHECKING, Annotated, Any, List, Literal
|
||||||
|
|
||||||
from autogpt_libs.auth import get_user_id
|
from autogpt_libs.auth import get_user_id
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
@@ -14,7 +14,7 @@ from fastapi import (
|
|||||||
Security,
|
Security,
|
||||||
status,
|
status,
|
||||||
)
|
)
|
||||||
from pydantic import BaseModel, Field, SecretStr
|
from pydantic import BaseModel, Field, SecretStr, model_validator
|
||||||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
|
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
|
||||||
|
|
||||||
from backend.api.features.library.db import set_preset_webhook, update_preset
|
from backend.api.features.library.db import set_preset_webhook, update_preset
|
||||||
@@ -39,7 +39,11 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
|||||||
from backend.data.user import get_user_integrations
|
from backend.data.user import get_user_integrations
|
||||||
from backend.executor.utils import add_graph_execution
|
from backend.executor.utils import add_graph_execution
|
||||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.credentials_store import provider_matches
|
||||||
|
from backend.integrations.creds_manager import (
|
||||||
|
IntegrationCredentialsManager,
|
||||||
|
create_mcp_oauth_handler,
|
||||||
|
)
|
||||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.integrations.webhooks import get_webhook_manager
|
from backend.integrations.webhooks import get_webhook_manager
|
||||||
@@ -102,9 +106,37 @@ class CredentialsMetaResponse(BaseModel):
|
|||||||
scopes: list[str] | None
|
scopes: list[str] | None
|
||||||
username: str | None
|
username: str | None
|
||||||
host: str | None = Field(
|
host: str | None = Field(
|
||||||
default=None, description="Host pattern for host-scoped credentials"
|
default=None,
|
||||||
|
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _normalize_provider(cls, data: Any) -> Any:
|
||||||
|
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug."""
|
||||||
|
if isinstance(data, dict):
|
||||||
|
prov = data.get("provider", "")
|
||||||
|
if isinstance(prov, str) and prov.startswith("ProviderName."):
|
||||||
|
member = prov.removeprefix("ProviderName.")
|
||||||
|
try:
|
||||||
|
data = {**data, "provider": ProviderName[member].value}
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
return data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_host(cred: Credentials) -> str | None:
|
||||||
|
"""Extract host from credential: HostScoped host or MCP server URL."""
|
||||||
|
if isinstance(cred, HostScopedCredentials):
|
||||||
|
return cred.host
|
||||||
|
if isinstance(cred, OAuth2Credentials) and cred.provider in (
|
||||||
|
ProviderName.MCP,
|
||||||
|
ProviderName.MCP.value,
|
||||||
|
"ProviderName.MCP",
|
||||||
|
):
|
||||||
|
return (cred.metadata or {}).get("mcp_server_url")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
|
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
|
||||||
async def callback(
|
async def callback(
|
||||||
@@ -179,9 +211,7 @@ async def callback(
|
|||||||
title=credentials.title,
|
title=credentials.title,
|
||||||
scopes=credentials.scopes,
|
scopes=credentials.scopes,
|
||||||
username=credentials.username,
|
username=credentials.username,
|
||||||
host=(
|
host=(CredentialsMetaResponse.get_host(credentials)),
|
||||||
credentials.host if isinstance(credentials, HostScopedCredentials) else None
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -199,7 +229,7 @@ async def list_credentials(
|
|||||||
title=cred.title,
|
title=cred.title,
|
||||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
host=CredentialsMetaResponse.get_host(cred),
|
||||||
)
|
)
|
||||||
for cred in credentials
|
for cred in credentials
|
||||||
]
|
]
|
||||||
@@ -222,7 +252,7 @@ async def list_credentials_by_provider(
|
|||||||
title=cred.title,
|
title=cred.title,
|
||||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
host=CredentialsMetaResponse.get_host(cred),
|
||||||
)
|
)
|
||||||
for cred in credentials
|
for cred in credentials
|
||||||
]
|
]
|
||||||
@@ -322,7 +352,11 @@ async def delete_credentials(
|
|||||||
|
|
||||||
tokens_revoked = None
|
tokens_revoked = None
|
||||||
if isinstance(creds, OAuth2Credentials):
|
if isinstance(creds, OAuth2Credentials):
|
||||||
handler = _get_provider_oauth_handler(request, provider)
|
if provider_matches(provider.value, ProviderName.MCP.value):
|
||||||
|
# MCP uses dynamic per-server OAuth — create handler from metadata
|
||||||
|
handler = create_mcp_oauth_handler(creds)
|
||||||
|
else:
|
||||||
|
handler = _get_provider_oauth_handler(request, provider)
|
||||||
tokens_revoked = await handler.revoke_tokens(creds)
|
tokens_revoked = await handler.revoke_tokens(creds)
|
||||||
|
|
||||||
return CredentialsDeletionResponse(revoked=tokens_revoked)
|
return CredentialsDeletionResponse(revoked=tokens_revoked)
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,6 @@ import prisma.enums
|
|||||||
import prisma.models
|
import prisma.models
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import backend.api.features.store.exceptions
|
|
||||||
from backend.data.db import connect
|
from backend.data.db import connect
|
||||||
from backend.data.includes import library_agent_include
|
from backend.data.includes import library_agent_include
|
||||||
|
|
||||||
@@ -144,6 +143,7 @@ async def test_add_agent_to_library(mocker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||||
|
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||||
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
|
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
|
||||||
mock_library_agent.return_value.create = mocker.AsyncMock(
|
mock_library_agent.return_value.create = mocker.AsyncMock(
|
||||||
return_value=mock_library_agent_data
|
return_value=mock_library_agent_data
|
||||||
@@ -178,7 +178,6 @@ async def test_add_agent_to_library(mocker):
|
|||||||
"agentGraphVersion": 1,
|
"agentGraphVersion": 1,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
include={"AgentGraph": True},
|
|
||||||
)
|
)
|
||||||
# Check that create was called with the expected data including settings
|
# Check that create was called with the expected data including settings
|
||||||
create_call_args = mock_library_agent.return_value.create.call_args
|
create_call_args = mock_library_agent.return_value.create.call_args
|
||||||
@@ -218,7 +217,7 @@ async def test_add_agent_to_library_not_found(mocker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Call function and verify exception
|
# Call function and verify exception
|
||||||
with pytest.raises(backend.api.features.store.exceptions.AgentNotFoundError):
|
with pytest.raises(db.NotFoundError):
|
||||||
await db.add_store_agent_to_library("version123", "test-user")
|
await db.add_store_agent_to_library("version123", "test-user")
|
||||||
|
|
||||||
# Verify mock called correctly
|
# Verify mock called correctly
|
||||||
|
|||||||
@@ -0,0 +1,10 @@
|
|||||||
|
class FolderValidationError(Exception):
|
||||||
|
"""Raised when folder operations fail validation."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FolderAlreadyExistsError(FolderValidationError):
|
||||||
|
"""Raised when a folder with the same name already exists in the location."""
|
||||||
|
|
||||||
|
pass
|
||||||
@@ -26,6 +26,95 @@ class LibraryAgentStatus(str, Enum):
|
|||||||
ERROR = "ERROR"
|
ERROR = "ERROR"
|
||||||
|
|
||||||
|
|
||||||
|
# === Folder Models ===
|
||||||
|
|
||||||
|
|
||||||
|
class LibraryFolder(pydantic.BaseModel):
|
||||||
|
"""Represents a folder for organizing library agents."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
user_id: str
|
||||||
|
name: str
|
||||||
|
icon: str | None = None
|
||||||
|
color: str | None = None
|
||||||
|
parent_id: str | None = None
|
||||||
|
created_at: datetime.datetime
|
||||||
|
updated_at: datetime.datetime
|
||||||
|
agent_count: int = 0 # Direct agents in folder
|
||||||
|
subfolder_count: int = 0 # Direct child folders
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_db(
|
||||||
|
folder: prisma.models.LibraryFolder,
|
||||||
|
agent_count: int = 0,
|
||||||
|
subfolder_count: int = 0,
|
||||||
|
) -> "LibraryFolder":
|
||||||
|
"""Factory method that constructs a LibraryFolder from a Prisma model."""
|
||||||
|
return LibraryFolder(
|
||||||
|
id=folder.id,
|
||||||
|
user_id=folder.userId,
|
||||||
|
name=folder.name,
|
||||||
|
icon=folder.icon,
|
||||||
|
color=folder.color,
|
||||||
|
parent_id=folder.parentId,
|
||||||
|
created_at=folder.createdAt,
|
||||||
|
updated_at=folder.updatedAt,
|
||||||
|
agent_count=agent_count,
|
||||||
|
subfolder_count=subfolder_count,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LibraryFolderTree(LibraryFolder):
|
||||||
|
"""Folder with nested children for tree view."""
|
||||||
|
|
||||||
|
children: list["LibraryFolderTree"] = []
|
||||||
|
|
||||||
|
|
||||||
|
class FolderCreateRequest(pydantic.BaseModel):
|
||||||
|
"""Request model for creating a folder."""
|
||||||
|
|
||||||
|
name: str = pydantic.Field(..., min_length=1, max_length=100)
|
||||||
|
icon: str | None = None
|
||||||
|
color: str | None = pydantic.Field(
|
||||||
|
None, pattern=r"^#[0-9A-Fa-f]{6}$", description="Hex color code (#RRGGBB)"
|
||||||
|
)
|
||||||
|
parent_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class FolderUpdateRequest(pydantic.BaseModel):
|
||||||
|
"""Request model for updating a folder."""
|
||||||
|
|
||||||
|
name: str | None = pydantic.Field(None, min_length=1, max_length=100)
|
||||||
|
icon: str | None = None
|
||||||
|
color: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class FolderMoveRequest(pydantic.BaseModel):
|
||||||
|
"""Request model for moving a folder to a new parent."""
|
||||||
|
|
||||||
|
target_parent_id: str | None = None # None = move to root
|
||||||
|
|
||||||
|
|
||||||
|
class BulkMoveAgentsRequest(pydantic.BaseModel):
|
||||||
|
"""Request model for moving multiple agents to a folder."""
|
||||||
|
|
||||||
|
agent_ids: list[str]
|
||||||
|
folder_id: str | None = None # None = move to root
|
||||||
|
|
||||||
|
|
||||||
|
class FolderListResponse(pydantic.BaseModel):
|
||||||
|
"""Response schema for a list of folders."""
|
||||||
|
|
||||||
|
folders: list[LibraryFolder]
|
||||||
|
pagination: Pagination
|
||||||
|
|
||||||
|
|
||||||
|
class FolderTreeResponse(pydantic.BaseModel):
|
||||||
|
"""Response schema for folder tree structure."""
|
||||||
|
|
||||||
|
tree: list[LibraryFolderTree]
|
||||||
|
|
||||||
|
|
||||||
class MarketplaceListingCreator(pydantic.BaseModel):
|
class MarketplaceListingCreator(pydantic.BaseModel):
|
||||||
"""Creator information for a marketplace listing."""
|
"""Creator information for a marketplace listing."""
|
||||||
|
|
||||||
@@ -76,7 +165,6 @@ class LibraryAgent(pydantic.BaseModel):
|
|||||||
id: str
|
id: str
|
||||||
graph_id: str
|
graph_id: str
|
||||||
graph_version: int
|
graph_version: int
|
||||||
owner_user_id: str
|
|
||||||
|
|
||||||
image_url: str | None
|
image_url: str | None
|
||||||
|
|
||||||
@@ -117,9 +205,14 @@ class LibraryAgent(pydantic.BaseModel):
|
|||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="List of recent executions with status, score, and summary",
|
description="List of recent executions with status, score, and summary",
|
||||||
)
|
)
|
||||||
can_access_graph: bool
|
can_access_graph: bool = pydantic.Field(
|
||||||
|
description="Indicates whether the same user owns the corresponding graph"
|
||||||
|
)
|
||||||
is_latest_version: bool
|
is_latest_version: bool
|
||||||
is_favorite: bool
|
is_favorite: bool
|
||||||
|
folder_id: str | None = None
|
||||||
|
folder_name: str | None = None # Denormalized for display
|
||||||
|
|
||||||
recommended_schedule_cron: str | None = None
|
recommended_schedule_cron: str | None = None
|
||||||
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
|
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
|
||||||
marketplace_listing: Optional["MarketplaceListing"] = None
|
marketplace_listing: Optional["MarketplaceListing"] = None
|
||||||
@@ -232,7 +325,6 @@ class LibraryAgent(pydantic.BaseModel):
|
|||||||
id=agent.id,
|
id=agent.id,
|
||||||
graph_id=agent.agentGraphId,
|
graph_id=agent.agentGraphId,
|
||||||
graph_version=agent.agentGraphVersion,
|
graph_version=agent.agentGraphVersion,
|
||||||
owner_user_id=agent.userId,
|
|
||||||
image_url=agent.imageUrl,
|
image_url=agent.imageUrl,
|
||||||
creator_name=creator_name,
|
creator_name=creator_name,
|
||||||
creator_image_url=creator_image_url,
|
creator_image_url=creator_image_url,
|
||||||
@@ -259,6 +351,8 @@ class LibraryAgent(pydantic.BaseModel):
|
|||||||
can_access_graph=can_access_graph,
|
can_access_graph=can_access_graph,
|
||||||
is_latest_version=is_latest_version,
|
is_latest_version=is_latest_version,
|
||||||
is_favorite=agent.isFavorite,
|
is_favorite=agent.isFavorite,
|
||||||
|
folder_id=agent.folderId,
|
||||||
|
folder_name=agent.Folder.name if agent.Folder else None,
|
||||||
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
|
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
|
||||||
settings=_parse_settings(agent.settings),
|
settings=_parse_settings(agent.settings),
|
||||||
marketplace_listing=marketplace_listing_data,
|
marketplace_listing=marketplace_listing_data,
|
||||||
@@ -470,3 +564,7 @@ class LibraryAgentUpdateRequest(pydantic.BaseModel):
|
|||||||
settings: Optional[GraphSettings] = pydantic.Field(
|
settings: Optional[GraphSettings] = pydantic.Field(
|
||||||
default=None, description="User-specific settings for this library agent"
|
default=None, description="User-specific settings for this library agent"
|
||||||
)
|
)
|
||||||
|
folder_id: Optional[str] = pydantic.Field(
|
||||||
|
default=None,
|
||||||
|
description="Folder ID to move agent to (None to move to root)",
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
import fastapi
|
import fastapi
|
||||||
|
|
||||||
from .agents import router as agents_router
|
from .agents import router as agents_router
|
||||||
|
from .folders import router as folders_router
|
||||||
from .presets import router as presets_router
|
from .presets import router as presets_router
|
||||||
|
|
||||||
router = fastapi.APIRouter()
|
router = fastapi.APIRouter()
|
||||||
|
|
||||||
router.include_router(presets_router)
|
router.include_router(presets_router)
|
||||||
|
router.include_router(folders_router)
|
||||||
router.include_router(agents_router)
|
router.include_router(agents_router)
|
||||||
|
|||||||
@@ -41,6 +41,14 @@ async def list_library_agents(
|
|||||||
ge=1,
|
ge=1,
|
||||||
description="Number of agents per page (must be >= 1)",
|
description="Number of agents per page (must be >= 1)",
|
||||||
),
|
),
|
||||||
|
folder_id: Optional[str] = Query(
|
||||||
|
None,
|
||||||
|
description="Filter by folder ID",
|
||||||
|
),
|
||||||
|
include_root_only: bool = Query(
|
||||||
|
False,
|
||||||
|
description="Only return agents without a folder (root-level agents)",
|
||||||
|
),
|
||||||
) -> library_model.LibraryAgentResponse:
|
) -> library_model.LibraryAgentResponse:
|
||||||
"""
|
"""
|
||||||
Get all agents in the user's library (both created and saved).
|
Get all agents in the user's library (both created and saved).
|
||||||
@@ -51,6 +59,8 @@ async def list_library_agents(
|
|||||||
sort_by=sort_by,
|
sort_by=sort_by,
|
||||||
page=page,
|
page=page,
|
||||||
page_size=page_size,
|
page_size=page_size,
|
||||||
|
folder_id=folder_id,
|
||||||
|
include_root_only=include_root_only,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -168,6 +178,7 @@ async def update_library_agent(
|
|||||||
is_favorite=payload.is_favorite,
|
is_favorite=payload.is_favorite,
|
||||||
is_archived=payload.is_archived,
|
is_archived=payload.is_archived,
|
||||||
settings=payload.settings,
|
settings=payload.settings,
|
||||||
|
folder_id=payload.folder_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,287 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import autogpt_libs.auth as autogpt_auth_lib
|
||||||
|
from fastapi import APIRouter, Query, Security, status
|
||||||
|
from fastapi.responses import Response
|
||||||
|
|
||||||
|
from .. import db as library_db
|
||||||
|
from .. import model as library_model
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/folders",
|
||||||
|
tags=["library", "folders", "private"],
|
||||||
|
dependencies=[Security(autogpt_auth_lib.requires_user)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"",
|
||||||
|
summary="List Library Folders",
|
||||||
|
response_model=library_model.FolderListResponse,
|
||||||
|
responses={
|
||||||
|
200: {"description": "List of folders"},
|
||||||
|
500: {"description": "Server error"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def list_folders(
|
||||||
|
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||||
|
parent_id: Optional[str] = Query(
|
||||||
|
None,
|
||||||
|
description="Filter by parent folder ID. If not provided, returns root-level folders.",
|
||||||
|
),
|
||||||
|
include_relations: bool = Query(
|
||||||
|
True,
|
||||||
|
description="Include agent and subfolder relations (for counts)",
|
||||||
|
),
|
||||||
|
) -> library_model.FolderListResponse:
|
||||||
|
"""
|
||||||
|
List folders for the authenticated user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: ID of the authenticated user.
|
||||||
|
parent_id: Optional parent folder ID to filter by.
|
||||||
|
include_relations: Whether to include agent and subfolder relations for counts.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A FolderListResponse containing folders.
|
||||||
|
"""
|
||||||
|
folders = await library_db.list_folders(
|
||||||
|
user_id=user_id,
|
||||||
|
parent_id=parent_id,
|
||||||
|
include_relations=include_relations,
|
||||||
|
)
|
||||||
|
return library_model.FolderListResponse(
|
||||||
|
folders=folders,
|
||||||
|
pagination=library_model.Pagination(
|
||||||
|
total_items=len(folders),
|
||||||
|
total_pages=1,
|
||||||
|
current_page=1,
|
||||||
|
page_size=len(folders),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/tree",
|
||||||
|
summary="Get Folder Tree",
|
||||||
|
response_model=library_model.FolderTreeResponse,
|
||||||
|
responses={
|
||||||
|
200: {"description": "Folder tree structure"},
|
||||||
|
500: {"description": "Server error"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def get_folder_tree(
|
||||||
|
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||||
|
) -> library_model.FolderTreeResponse:
|
||||||
|
"""
|
||||||
|
Get the full folder tree for the authenticated user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: ID of the authenticated user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A FolderTreeResponse containing the nested folder structure.
|
||||||
|
"""
|
||||||
|
tree = await library_db.get_folder_tree(user_id=user_id)
|
||||||
|
return library_model.FolderTreeResponse(tree=tree)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/{folder_id}",
|
||||||
|
summary="Get Folder",
|
||||||
|
response_model=library_model.LibraryFolder,
|
||||||
|
responses={
|
||||||
|
200: {"description": "Folder details"},
|
||||||
|
404: {"description": "Folder not found"},
|
||||||
|
500: {"description": "Server error"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def get_folder(
|
||||||
|
folder_id: str,
|
||||||
|
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||||
|
) -> library_model.LibraryFolder:
|
||||||
|
"""
|
||||||
|
Get a specific folder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
folder_id: ID of the folder to retrieve.
|
||||||
|
user_id: ID of the authenticated user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The requested LibraryFolder.
|
||||||
|
"""
|
||||||
|
return await library_db.get_folder(folder_id=folder_id, user_id=user_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"",
|
||||||
|
summary="Create Folder",
|
||||||
|
status_code=status.HTTP_201_CREATED,
|
||||||
|
response_model=library_model.LibraryFolder,
|
||||||
|
responses={
|
||||||
|
201: {"description": "Folder created successfully"},
|
||||||
|
400: {"description": "Validation error"},
|
||||||
|
404: {"description": "Parent folder not found"},
|
||||||
|
409: {"description": "Folder name conflict"},
|
||||||
|
500: {"description": "Server error"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def create_folder(
|
||||||
|
payload: library_model.FolderCreateRequest,
|
||||||
|
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||||
|
) -> library_model.LibraryFolder:
|
||||||
|
"""
|
||||||
|
Create a new folder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload: The folder creation request.
|
||||||
|
user_id: ID of the authenticated user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created LibraryFolder.
|
||||||
|
"""
|
||||||
|
return await library_db.create_folder(
|
||||||
|
user_id=user_id,
|
||||||
|
name=payload.name,
|
||||||
|
parent_id=payload.parent_id,
|
||||||
|
icon=payload.icon,
|
||||||
|
color=payload.color,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch(
|
||||||
|
"/{folder_id}",
|
||||||
|
summary="Update Folder",
|
||||||
|
response_model=library_model.LibraryFolder,
|
||||||
|
responses={
|
||||||
|
200: {"description": "Folder updated successfully"},
|
||||||
|
400: {"description": "Validation error"},
|
||||||
|
404: {"description": "Folder not found"},
|
||||||
|
409: {"description": "Folder name conflict"},
|
||||||
|
500: {"description": "Server error"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def update_folder(
|
||||||
|
folder_id: str,
|
||||||
|
payload: library_model.FolderUpdateRequest,
|
||||||
|
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||||
|
) -> library_model.LibraryFolder:
|
||||||
|
"""
|
||||||
|
Update a folder's properties.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
folder_id: ID of the folder to update.
|
||||||
|
payload: The folder update request.
|
||||||
|
user_id: ID of the authenticated user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The updated LibraryFolder.
|
||||||
|
"""
|
||||||
|
return await library_db.update_folder(
|
||||||
|
folder_id=folder_id,
|
||||||
|
user_id=user_id,
|
||||||
|
name=payload.name,
|
||||||
|
icon=payload.icon,
|
||||||
|
color=payload.color,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/{folder_id}/move",
|
||||||
|
summary="Move Folder",
|
||||||
|
response_model=library_model.LibraryFolder,
|
||||||
|
responses={
|
||||||
|
200: {"description": "Folder moved successfully"},
|
||||||
|
400: {"description": "Validation error (circular reference)"},
|
||||||
|
404: {"description": "Folder or target parent not found"},
|
||||||
|
409: {"description": "Folder name conflict in target location"},
|
||||||
|
500: {"description": "Server error"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def move_folder(
|
||||||
|
folder_id: str,
|
||||||
|
payload: library_model.FolderMoveRequest,
|
||||||
|
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||||
|
) -> library_model.LibraryFolder:
|
||||||
|
"""
|
||||||
|
Move a folder to a new parent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
folder_id: ID of the folder to move.
|
||||||
|
payload: The move request with target parent.
|
||||||
|
user_id: ID of the authenticated user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The moved LibraryFolder.
|
||||||
|
"""
|
||||||
|
return await library_db.move_folder(
|
||||||
|
folder_id=folder_id,
|
||||||
|
user_id=user_id,
|
||||||
|
target_parent_id=payload.target_parent_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete(
|
||||||
|
"/{folder_id}",
|
||||||
|
summary="Delete Folder",
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
responses={
|
||||||
|
204: {"description": "Folder deleted successfully"},
|
||||||
|
404: {"description": "Folder not found"},
|
||||||
|
500: {"description": "Server error"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def delete_folder(
|
||||||
|
folder_id: str,
|
||||||
|
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||||
|
) -> Response:
|
||||||
|
"""
|
||||||
|
Soft-delete a folder and all its contents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
folder_id: ID of the folder to delete.
|
||||||
|
user_id: ID of the authenticated user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
204 No Content if successful.
|
||||||
|
"""
|
||||||
|
await library_db.delete_folder(
|
||||||
|
folder_id=folder_id,
|
||||||
|
user_id=user_id,
|
||||||
|
soft_delete=True,
|
||||||
|
)
|
||||||
|
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
|
||||||
|
|
||||||
|
# === Bulk Agent Operations ===
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/agents/bulk-move",
|
||||||
|
summary="Bulk Move Agents",
|
||||||
|
response_model=list[library_model.LibraryAgent],
|
||||||
|
responses={
|
||||||
|
200: {"description": "Agents moved successfully"},
|
||||||
|
404: {"description": "Folder not found"},
|
||||||
|
500: {"description": "Server error"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def bulk_move_agents(
|
||||||
|
payload: library_model.BulkMoveAgentsRequest,
|
||||||
|
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||||
|
) -> list[library_model.LibraryAgent]:
|
||||||
|
"""
|
||||||
|
Move multiple agents to a folder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload: The bulk move request with agent IDs and target folder.
|
||||||
|
user_id: ID of the authenticated user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The updated LibraryAgents.
|
||||||
|
"""
|
||||||
|
return await library_db.bulk_move_agents_to_folder(
|
||||||
|
agent_ids=payload.agent_ids,
|
||||||
|
folder_id=payload.folder_id,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
@@ -42,7 +42,6 @@ async def test_get_library_agents_success(
|
|||||||
id="test-agent-1",
|
id="test-agent-1",
|
||||||
graph_id="test-agent-1",
|
graph_id="test-agent-1",
|
||||||
graph_version=1,
|
graph_version=1,
|
||||||
owner_user_id=test_user_id,
|
|
||||||
name="Test Agent 1",
|
name="Test Agent 1",
|
||||||
description="Test Description 1",
|
description="Test Description 1",
|
||||||
image_url=None,
|
image_url=None,
|
||||||
@@ -67,7 +66,6 @@ async def test_get_library_agents_success(
|
|||||||
id="test-agent-2",
|
id="test-agent-2",
|
||||||
graph_id="test-agent-2",
|
graph_id="test-agent-2",
|
||||||
graph_version=1,
|
graph_version=1,
|
||||||
owner_user_id=test_user_id,
|
|
||||||
name="Test Agent 2",
|
name="Test Agent 2",
|
||||||
description="Test Description 2",
|
description="Test Description 2",
|
||||||
image_url=None,
|
image_url=None,
|
||||||
@@ -115,6 +113,8 @@ async def test_get_library_agents_success(
|
|||||||
sort_by=library_model.LibraryAgentSort.UPDATED_AT,
|
sort_by=library_model.LibraryAgentSort.UPDATED_AT,
|
||||||
page=1,
|
page=1,
|
||||||
page_size=15,
|
page_size=15,
|
||||||
|
folder_id=None,
|
||||||
|
include_root_only=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -129,7 +129,6 @@ async def test_get_favorite_library_agents_success(
|
|||||||
id="test-agent-1",
|
id="test-agent-1",
|
||||||
graph_id="test-agent-1",
|
graph_id="test-agent-1",
|
||||||
graph_version=1,
|
graph_version=1,
|
||||||
owner_user_id=test_user_id,
|
|
||||||
name="Favorite Agent 1",
|
name="Favorite Agent 1",
|
||||||
description="Test Favorite Description 1",
|
description="Test Favorite Description 1",
|
||||||
image_url=None,
|
image_url=None,
|
||||||
@@ -182,7 +181,6 @@ def test_add_agent_to_library_success(
|
|||||||
id="test-library-agent-id",
|
id="test-library-agent-id",
|
||||||
graph_id="test-agent-1",
|
graph_id="test-agent-1",
|
||||||
graph_version=1,
|
graph_version=1,
|
||||||
owner_user_id=test_user_id,
|
|
||||||
name="Test Agent 1",
|
name="Test Agent 1",
|
||||||
description="Test Description 1",
|
description="Test Description 1",
|
||||||
image_url=None,
|
image_url=None,
|
||||||
|
|||||||
511
autogpt_platform/backend/backend/api/features/mcp/routes.py
Normal file
511
autogpt_platform/backend/backend/api/features/mcp/routes.py
Normal file
@@ -0,0 +1,511 @@
|
|||||||
|
"""
|
||||||
|
MCP (Model Context Protocol) API routes.
|
||||||
|
|
||||||
|
Provides endpoints for MCP tool discovery and OAuth authentication so the
|
||||||
|
frontend can list available tools on an MCP server before placing a block.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
import fastapi
|
||||||
|
from autogpt_libs.auth import get_user_id
|
||||||
|
from fastapi import Security
|
||||||
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
|
from backend.api.features.integrations.router import CredentialsMetaResponse
|
||||||
|
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
||||||
|
from backend.blocks.mcp.helpers import (
|
||||||
|
auto_lookup_mcp_credential,
|
||||||
|
normalize_mcp_url,
|
||||||
|
server_host,
|
||||||
|
)
|
||||||
|
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||||
|
from backend.data.model import OAuth2Credentials
|
||||||
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util.request import HTTPClientError, Requests, validate_url_host
|
||||||
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
|
router = fastapi.APIRouter(tags=["mcp"])
|
||||||
|
creds_manager = IntegrationCredentialsManager()
|
||||||
|
|
||||||
|
|
||||||
|
# ====================== Tool Discovery ====================== #
|
||||||
|
|
||||||
|
|
||||||
|
class DiscoverToolsRequest(BaseModel):
|
||||||
|
"""Request to discover tools on an MCP server."""
|
||||||
|
|
||||||
|
server_url: str = Field(description="URL of the MCP server")
|
||||||
|
auth_token: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional Bearer token for authenticated MCP servers",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPToolResponse(BaseModel):
|
||||||
|
"""A single MCP tool returned by discovery."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
input_schema: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class DiscoverToolsResponse(BaseModel):
|
||||||
|
"""Response containing the list of tools available on an MCP server."""
|
||||||
|
|
||||||
|
tools: list[MCPToolResponse]
|
||||||
|
server_name: str | None = None
|
||||||
|
protocol_version: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/discover-tools",
|
||||||
|
summary="Discover available tools on an MCP server",
|
||||||
|
response_model=DiscoverToolsResponse,
|
||||||
|
)
|
||||||
|
async def discover_tools(
|
||||||
|
request: DiscoverToolsRequest,
|
||||||
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
|
) -> DiscoverToolsResponse:
|
||||||
|
"""
|
||||||
|
Connect to an MCP server and return its available tools.
|
||||||
|
|
||||||
|
If the user has a stored MCP credential for this server URL, it will be
|
||||||
|
used automatically — no need to pass an explicit auth token.
|
||||||
|
"""
|
||||||
|
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||||
|
try:
|
||||||
|
await validate_url_host(request.server_url)
|
||||||
|
except ValueError as e:
|
||||||
|
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||||
|
|
||||||
|
auth_token = request.auth_token
|
||||||
|
|
||||||
|
# Auto-use stored MCP credential when no explicit token is provided.
|
||||||
|
if not auth_token:
|
||||||
|
best_cred = await auto_lookup_mcp_credential(
|
||||||
|
user_id, normalize_mcp_url(request.server_url)
|
||||||
|
)
|
||||||
|
if best_cred:
|
||||||
|
auth_token = best_cred.access_token.get_secret_value()
|
||||||
|
|
||||||
|
client = MCPClient(request.server_url, auth_token=auth_token)
|
||||||
|
|
||||||
|
try:
|
||||||
|
init_result = await client.initialize()
|
||||||
|
tools = await client.list_tools()
|
||||||
|
except HTTPClientError as e:
|
||||||
|
if e.status_code in (401, 403):
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail="This MCP server requires authentication. "
|
||||||
|
"Please provide a valid auth token.",
|
||||||
|
)
|
||||||
|
raise fastapi.HTTPException(status_code=502, detail=str(e))
|
||||||
|
except MCPClientError as e:
|
||||||
|
raise fastapi.HTTPException(status_code=502, detail=str(e))
|
||||||
|
except Exception as e:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail=f"Failed to connect to MCP server: {e}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return DiscoverToolsResponse(
|
||||||
|
tools=[
|
||||||
|
MCPToolResponse(
|
||||||
|
name=t.name,
|
||||||
|
description=t.description,
|
||||||
|
input_schema=t.input_schema,
|
||||||
|
)
|
||||||
|
for t in tools
|
||||||
|
],
|
||||||
|
server_name=(
|
||||||
|
init_result.get("serverInfo", {}).get("name")
|
||||||
|
or server_host(request.server_url)
|
||||||
|
or "MCP"
|
||||||
|
),
|
||||||
|
protocol_version=init_result.get("protocolVersion"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ======================== OAuth Flow ======================== #
|
||||||
|
|
||||||
|
|
||||||
|
class MCPOAuthLoginRequest(BaseModel):
|
||||||
|
"""Request to start an OAuth flow for an MCP server."""
|
||||||
|
|
||||||
|
server_url: str = Field(description="URL of the MCP server that requires OAuth")
|
||||||
|
|
||||||
|
|
||||||
|
class MCPOAuthLoginResponse(BaseModel):
|
||||||
|
"""Response with the OAuth login URL for the user to authenticate."""
|
||||||
|
|
||||||
|
login_url: str
|
||||||
|
state_token: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/oauth/login",
|
||||||
|
summary="Initiate OAuth login for an MCP server",
|
||||||
|
)
|
||||||
|
async def mcp_oauth_login(
|
||||||
|
request: MCPOAuthLoginRequest,
|
||||||
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
|
) -> MCPOAuthLoginResponse:
|
||||||
|
"""
|
||||||
|
Discover OAuth metadata from the MCP server and return a login URL.
|
||||||
|
|
||||||
|
1. Discovers the protected-resource metadata (RFC 9728)
|
||||||
|
2. Fetches the authorization server metadata (RFC 8414)
|
||||||
|
3. Performs Dynamic Client Registration (RFC 7591) if available
|
||||||
|
4. Returns the authorization URL for the frontend to open in a popup
|
||||||
|
"""
|
||||||
|
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||||
|
try:
|
||||||
|
await validate_url_host(request.server_url)
|
||||||
|
except ValueError as e:
|
||||||
|
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||||
|
|
||||||
|
# Normalize the URL so that credentials stored here are matched consistently
|
||||||
|
# by auto_lookup_mcp_credential (which also uses normalized URLs).
|
||||||
|
server_url = normalize_mcp_url(request.server_url)
|
||||||
|
client = MCPClient(server_url)
|
||||||
|
|
||||||
|
# Step 1: Discover protected-resource metadata (RFC 9728)
|
||||||
|
protected_resource = await client.discover_auth()
|
||||||
|
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
if protected_resource and protected_resource.get("authorization_servers"):
|
||||||
|
auth_server_url = protected_resource["authorization_servers"][0]
|
||||||
|
resource_url = protected_resource.get("resource", server_url)
|
||||||
|
|
||||||
|
# Validate the auth server URL from metadata to prevent SSRF.
|
||||||
|
try:
|
||||||
|
await validate_url_host(auth_server_url)
|
||||||
|
except ValueError as e:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Invalid authorization server URL in metadata: {e}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 2a: Discover auth-server metadata (RFC 8414)
|
||||||
|
metadata = await client.discover_auth_server_metadata(auth_server_url)
|
||||||
|
else:
|
||||||
|
# Fallback: Some MCP servers (e.g. Linear) are their own auth server
|
||||||
|
# and serve OAuth metadata directly without protected-resource metadata.
|
||||||
|
# Don't assume a resource_url — omitting it lets the auth server choose
|
||||||
|
# the correct audience for the token (RFC 8707 resource is optional).
|
||||||
|
resource_url = None
|
||||||
|
metadata = await client.discover_auth_server_metadata(server_url)
|
||||||
|
|
||||||
|
if (
|
||||||
|
not metadata
|
||||||
|
or "authorization_endpoint" not in metadata
|
||||||
|
or "token_endpoint" not in metadata
|
||||||
|
):
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="This MCP server does not advertise OAuth support. "
|
||||||
|
"You may need to provide an auth token manually.",
|
||||||
|
)
|
||||||
|
|
||||||
|
authorize_url = metadata["authorization_endpoint"]
|
||||||
|
token_url = metadata["token_endpoint"]
|
||||||
|
registration_endpoint = metadata.get("registration_endpoint")
|
||||||
|
revoke_url = metadata.get("revocation_endpoint")
|
||||||
|
|
||||||
|
# Step 3: Dynamic Client Registration (RFC 7591) if available
|
||||||
|
frontend_base_url = settings.config.frontend_base_url
|
||||||
|
if not frontend_base_url:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="Frontend base URL is not configured.",
|
||||||
|
)
|
||||||
|
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
|
||||||
|
|
||||||
|
client_id = ""
|
||||||
|
client_secret = ""
|
||||||
|
if registration_endpoint:
|
||||||
|
# Validate the registration endpoint to prevent SSRF via metadata.
|
||||||
|
try:
|
||||||
|
await validate_url_host(registration_endpoint)
|
||||||
|
except ValueError:
|
||||||
|
pass # Skip registration, fall back to default client_id
|
||||||
|
else:
|
||||||
|
reg_result = await _register_mcp_client(
|
||||||
|
registration_endpoint, redirect_uri, server_url
|
||||||
|
)
|
||||||
|
if reg_result:
|
||||||
|
client_id = reg_result.get("client_id", "")
|
||||||
|
client_secret = reg_result.get("client_secret", "")
|
||||||
|
|
||||||
|
if not client_id:
|
||||||
|
client_id = "autogpt-platform"
|
||||||
|
|
||||||
|
# Step 4: Store state token with OAuth metadata for the callback
|
||||||
|
scopes = (protected_resource or {}).get("scopes_supported") or metadata.get(
|
||||||
|
"scopes_supported", []
|
||||||
|
)
|
||||||
|
state_token, code_challenge = await creds_manager.store.store_state_token(
|
||||||
|
user_id,
|
||||||
|
ProviderName.MCP.value,
|
||||||
|
scopes,
|
||||||
|
state_metadata={
|
||||||
|
"authorize_url": authorize_url,
|
||||||
|
"token_url": token_url,
|
||||||
|
"revoke_url": revoke_url,
|
||||||
|
"resource_url": resource_url,
|
||||||
|
"server_url": server_url,
|
||||||
|
"client_id": client_id,
|
||||||
|
"client_secret": client_secret,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 5: Build and return the login URL
|
||||||
|
handler = MCPOAuthHandler(
|
||||||
|
client_id=client_id,
|
||||||
|
client_secret=client_secret,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
authorize_url=authorize_url,
|
||||||
|
token_url=token_url,
|
||||||
|
resource_url=resource_url,
|
||||||
|
)
|
||||||
|
login_url = handler.get_login_url(
|
||||||
|
scopes, state_token, code_challenge=code_challenge
|
||||||
|
)
|
||||||
|
|
||||||
|
return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPOAuthCallbackRequest(BaseModel):
|
||||||
|
"""Request to exchange an OAuth code for tokens."""
|
||||||
|
|
||||||
|
code: str = Field(description="Authorization code from OAuth callback")
|
||||||
|
state_token: str = Field(description="State token for CSRF verification")
|
||||||
|
|
||||||
|
|
||||||
|
class MCPOAuthCallbackResponse(BaseModel):
|
||||||
|
"""Response after successfully storing OAuth credentials."""
|
||||||
|
|
||||||
|
credential_id: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/oauth/callback",
|
||||||
|
summary="Exchange OAuth code for MCP tokens",
|
||||||
|
)
|
||||||
|
async def mcp_oauth_callback(
|
||||||
|
request: MCPOAuthCallbackRequest,
|
||||||
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
|
) -> CredentialsMetaResponse:
|
||||||
|
"""
|
||||||
|
Exchange the authorization code for tokens and store the credential.
|
||||||
|
|
||||||
|
The frontend calls this after receiving the OAuth code from the popup.
|
||||||
|
On success, subsequent ``/discover-tools`` calls for the same server URL
|
||||||
|
will automatically use the stored credential.
|
||||||
|
"""
|
||||||
|
valid_state = await creds_manager.store.verify_state_token(
|
||||||
|
user_id, request.state_token, ProviderName.MCP.value
|
||||||
|
)
|
||||||
|
if not valid_state:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Invalid or expired state token.",
|
||||||
|
)
|
||||||
|
|
||||||
|
meta = valid_state.state_metadata
|
||||||
|
frontend_base_url = settings.config.frontend_base_url
|
||||||
|
if not frontend_base_url:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="Frontend base URL is not configured.",
|
||||||
|
)
|
||||||
|
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
|
||||||
|
|
||||||
|
handler = MCPOAuthHandler(
|
||||||
|
client_id=meta["client_id"],
|
||||||
|
client_secret=meta.get("client_secret", ""),
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
authorize_url=meta["authorize_url"],
|
||||||
|
token_url=meta["token_url"],
|
||||||
|
revoke_url=meta.get("revoke_url"),
|
||||||
|
resource_url=meta.get("resource_url"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
credentials = await handler.exchange_code_for_tokens(
|
||||||
|
request.code, valid_state.scopes, valid_state.code_verifier
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"OAuth token exchange failed: {e}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Enrich credential metadata for future lookup and token refresh
|
||||||
|
if credentials.metadata is None:
|
||||||
|
credentials.metadata = {}
|
||||||
|
credentials.metadata["mcp_server_url"] = meta["server_url"]
|
||||||
|
credentials.metadata["mcp_client_id"] = meta["client_id"]
|
||||||
|
credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "")
|
||||||
|
credentials.metadata["mcp_token_url"] = meta["token_url"]
|
||||||
|
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
|
||||||
|
|
||||||
|
hostname = server_host(meta["server_url"])
|
||||||
|
credentials.title = f"MCP: {hostname}"
|
||||||
|
|
||||||
|
# Remove old MCP credentials for the same server to prevent stale token buildup.
|
||||||
|
try:
|
||||||
|
old_creds = await creds_manager.store.get_creds_by_provider(
|
||||||
|
user_id, ProviderName.MCP.value
|
||||||
|
)
|
||||||
|
for old in old_creds:
|
||||||
|
if (
|
||||||
|
isinstance(old, OAuth2Credentials)
|
||||||
|
and (old.metadata or {}).get("mcp_server_url") == meta["server_url"]
|
||||||
|
):
|
||||||
|
await creds_manager.store.delete_creds_by_id(user_id, old.id)
|
||||||
|
logger.info(
|
||||||
|
"Removed old MCP credential %s for %s",
|
||||||
|
old.id,
|
||||||
|
server_host(meta["server_url"]),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Could not clean up old MCP credentials", exc_info=True)
|
||||||
|
|
||||||
|
await creds_manager.create(user_id, credentials)
|
||||||
|
|
||||||
|
return CredentialsMetaResponse(
|
||||||
|
id=credentials.id,
|
||||||
|
provider=credentials.provider,
|
||||||
|
type=credentials.type,
|
||||||
|
title=credentials.title,
|
||||||
|
scopes=credentials.scopes,
|
||||||
|
username=credentials.username,
|
||||||
|
host=credentials.metadata.get("mcp_server_url"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ======================== Bearer Token ======================== #
|
||||||
|
|
||||||
|
|
||||||
|
class MCPStoreTokenRequest(BaseModel):
|
||||||
|
"""Request to store a bearer token for an MCP server that doesn't support OAuth."""
|
||||||
|
|
||||||
|
server_url: str = Field(
|
||||||
|
description="MCP server URL the token authenticates against"
|
||||||
|
)
|
||||||
|
token: SecretStr = Field(
|
||||||
|
min_length=1, description="Bearer token / API key for the MCP server"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/token",
|
||||||
|
summary="Store a bearer token for an MCP server",
|
||||||
|
)
|
||||||
|
async def mcp_store_token(
|
||||||
|
request: MCPStoreTokenRequest,
|
||||||
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
|
) -> CredentialsMetaResponse:
|
||||||
|
"""
|
||||||
|
Store a manually provided bearer token as an MCP credential.
|
||||||
|
|
||||||
|
Used by the Copilot MCPSetupCard when the server doesn't support the MCP
|
||||||
|
OAuth discovery flow (returns 400 from /oauth/login). Subsequent
|
||||||
|
``run_mcp_tool`` calls will automatically pick up the token via
|
||||||
|
``_auto_lookup_credential``.
|
||||||
|
"""
|
||||||
|
token = request.token.get_secret_value().strip()
|
||||||
|
if not token:
|
||||||
|
raise fastapi.HTTPException(status_code=422, detail="Token must not be blank.")
|
||||||
|
|
||||||
|
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
|
||||||
|
try:
|
||||||
|
await validate_url_host(request.server_url)
|
||||||
|
except ValueError as e:
|
||||||
|
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
|
||||||
|
|
||||||
|
# Normalize URL so trailing-slash variants match existing credentials.
|
||||||
|
server_url = normalize_mcp_url(request.server_url)
|
||||||
|
hostname = server_host(server_url)
|
||||||
|
|
||||||
|
# Collect IDs of old credentials to clean up after successful create.
|
||||||
|
old_cred_ids: list[str] = []
|
||||||
|
try:
|
||||||
|
old_creds = await creds_manager.store.get_creds_by_provider(
|
||||||
|
user_id, ProviderName.MCP.value
|
||||||
|
)
|
||||||
|
old_cred_ids = [
|
||||||
|
old.id
|
||||||
|
for old in old_creds
|
||||||
|
if isinstance(old, OAuth2Credentials)
|
||||||
|
and normalize_mcp_url((old.metadata or {}).get("mcp_server_url", ""))
|
||||||
|
== server_url
|
||||||
|
]
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Could not query old MCP token credentials", exc_info=True)
|
||||||
|
|
||||||
|
credentials = OAuth2Credentials(
|
||||||
|
provider=ProviderName.MCP.value,
|
||||||
|
title=f"MCP: {hostname}",
|
||||||
|
access_token=SecretStr(token),
|
||||||
|
scopes=[],
|
||||||
|
metadata={"mcp_server_url": server_url},
|
||||||
|
)
|
||||||
|
await creds_manager.create(user_id, credentials)
|
||||||
|
|
||||||
|
# Only delete old credentials after the new one is safely stored.
|
||||||
|
for old_id in old_cred_ids:
|
||||||
|
try:
|
||||||
|
await creds_manager.store.delete_creds_by_id(user_id, old_id)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Could not clean up old MCP token credential", exc_info=True)
|
||||||
|
|
||||||
|
return CredentialsMetaResponse(
|
||||||
|
id=credentials.id,
|
||||||
|
provider=credentials.provider,
|
||||||
|
type=credentials.type,
|
||||||
|
title=credentials.title,
|
||||||
|
scopes=credentials.scopes,
|
||||||
|
username=credentials.username,
|
||||||
|
host=hostname,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ======================== Helpers ======================== #
|
||||||
|
|
||||||
|
|
||||||
|
async def _register_mcp_client(
|
||||||
|
registration_endpoint: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
server_url: str,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server."""
|
||||||
|
try:
|
||||||
|
response = await Requests(raise_for_status=True).post(
|
||||||
|
registration_endpoint,
|
||||||
|
json={
|
||||||
|
"client_name": "AutoGPT Platform",
|
||||||
|
"redirect_uris": [redirect_uri],
|
||||||
|
"grant_types": ["authorization_code"],
|
||||||
|
"response_types": ["code"],
|
||||||
|
"token_endpoint_auth_method": "client_secret_post",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
data = response.json()
|
||||||
|
if isinstance(data, dict) and "client_id" in data:
|
||||||
|
return data
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Dynamic client registration failed for %s: %s", server_host(server_url), e
|
||||||
|
)
|
||||||
|
return None
|
||||||
572
autogpt_platform/backend/backend/api/features/mcp/test_routes.py
Normal file
572
autogpt_platform/backend/backend/api/features/mcp/test_routes.py
Normal file
@@ -0,0 +1,572 @@
|
|||||||
|
"""Tests for MCP API routes.
|
||||||
|
|
||||||
|
Uses httpx.AsyncClient with ASGITransport instead of fastapi.testclient.TestClient
|
||||||
|
to avoid creating blocking portals that can corrupt pytest-asyncio's session event loop.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import fastapi
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from autogpt_libs.auth import get_user_id
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.api.features.mcp.routes import router
|
||||||
|
from backend.blocks.mcp.client import MCPClientError, MCPTool
|
||||||
|
from backend.data.model import OAuth2Credentials
|
||||||
|
from backend.util.request import HTTPClientError
|
||||||
|
|
||||||
|
app = fastapi.FastAPI()
|
||||||
|
app.include_router(router)
|
||||||
|
app.dependency_overrides[get_user_id] = lambda: "test-user-id"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="module")
|
||||||
|
async def client():
|
||||||
|
transport = httpx.ASGITransport(app=app)
|
||||||
|
async with httpx.AsyncClient(transport=transport, base_url="http://test") as c:
|
||||||
|
yield c
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _bypass_ssrf_validation():
|
||||||
|
"""Bypass validate_url_host in all route tests (test URLs don't resolve)."""
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.mcp.routes.validate_url_host",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
class TestDiscoverTools:
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_discover_tools_success(self, client):
|
||||||
|
mock_tools = [
|
||||||
|
MCPTool(
|
||||||
|
name="get_weather",
|
||||||
|
description="Get weather for a city",
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"city": {"type": "string"}},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
MCPTool(
|
||||||
|
name="add_numbers",
|
||||||
|
description="Add two numbers",
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"a": {"type": "number"},
|
||||||
|
"b": {"type": "number"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||||
|
patch(
|
||||||
|
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=None,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
instance = MockClient.return_value
|
||||||
|
instance.initialize = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"protocolVersion": "2025-03-26",
|
||||||
|
"serverInfo": {"name": "test-server"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
instance.list_tools = AsyncMock(return_value=mock_tools)
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/discover-tools",
|
||||||
|
json={"server_url": "https://mcp.example.com/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["tools"]) == 2
|
||||||
|
assert data["tools"][0]["name"] == "get_weather"
|
||||||
|
assert data["tools"][1]["name"] == "add_numbers"
|
||||||
|
assert data["server_name"] == "test-server"
|
||||||
|
assert data["protocol_version"] == "2025-03-26"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_discover_tools_with_auth_token(self, client):
|
||||||
|
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
||||||
|
instance = MockClient.return_value
|
||||||
|
instance.initialize = AsyncMock(
|
||||||
|
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
|
||||||
|
)
|
||||||
|
instance.list_tools = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/discover-tools",
|
||||||
|
json={
|
||||||
|
"server_url": "https://mcp.example.com/mcp",
|
||||||
|
"auth_token": "my-secret-token",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
MockClient.assert_called_once_with(
|
||||||
|
"https://mcp.example.com/mcp",
|
||||||
|
auth_token="my-secret-token",
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_discover_tools_auto_uses_stored_credential(self, client):
|
||||||
|
"""When no explicit token is given, stored MCP credentials are used."""
|
||||||
|
stored_cred = OAuth2Credentials(
|
||||||
|
provider="mcp",
|
||||||
|
title="MCP: example.com",
|
||||||
|
access_token=SecretStr("stored-token-123"),
|
||||||
|
refresh_token=None,
|
||||||
|
access_token_expires_at=None,
|
||||||
|
refresh_token_expires_at=None,
|
||||||
|
scopes=[],
|
||||||
|
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||||
|
patch(
|
||||||
|
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=stored_cred,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
instance = MockClient.return_value
|
||||||
|
instance.initialize = AsyncMock(
|
||||||
|
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
|
||||||
|
)
|
||||||
|
instance.list_tools = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/discover-tools",
|
||||||
|
json={"server_url": "https://mcp.example.com/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
MockClient.assert_called_once_with(
|
||||||
|
"https://mcp.example.com/mcp",
|
||||||
|
auth_token="stored-token-123",
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_discover_tools_mcp_error(self, client):
|
||||||
|
with (
|
||||||
|
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||||
|
patch(
|
||||||
|
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=None,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
instance = MockClient.return_value
|
||||||
|
instance.initialize = AsyncMock(
|
||||||
|
side_effect=MCPClientError("Connection refused")
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/discover-tools",
|
||||||
|
json={"server_url": "https://bad-server.example.com/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 502
|
||||||
|
assert "Connection refused" in response.json()["detail"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_discover_tools_generic_error(self, client):
|
||||||
|
with (
|
||||||
|
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||||
|
patch(
|
||||||
|
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=None,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
instance = MockClient.return_value
|
||||||
|
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/discover-tools",
|
||||||
|
json={"server_url": "https://timeout.example.com/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 502
|
||||||
|
assert "Failed to connect" in response.json()["detail"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_discover_tools_auth_required(self, client):
|
||||||
|
with (
|
||||||
|
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||||
|
patch(
|
||||||
|
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=None,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
instance = MockClient.return_value
|
||||||
|
instance.initialize = AsyncMock(
|
||||||
|
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/discover-tools",
|
||||||
|
json={"server_url": "https://auth-server.example.com/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "requires authentication" in response.json()["detail"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_discover_tools_forbidden(self, client):
|
||||||
|
with (
|
||||||
|
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||||
|
patch(
|
||||||
|
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=None,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
instance = MockClient.return_value
|
||||||
|
instance.initialize = AsyncMock(
|
||||||
|
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/discover-tools",
|
||||||
|
json={"server_url": "https://auth-server.example.com/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "requires authentication" in response.json()["detail"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_discover_tools_missing_url(self, client):
|
||||||
|
response = await client.post("/discover-tools", json={})
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuthLogin:
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_oauth_login_success(self, client):
|
||||||
|
with (
|
||||||
|
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||||
|
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||||
|
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
||||||
|
patch(
|
||||||
|
"backend.api.features.mcp.routes._register_mcp_client"
|
||||||
|
) as mock_register,
|
||||||
|
):
|
||||||
|
instance = MockClient.return_value
|
||||||
|
instance.discover_auth = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"authorization_servers": ["https://auth.sentry.io"],
|
||||||
|
"resource": "https://mcp.sentry.dev/mcp",
|
||||||
|
"scopes_supported": ["openid"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
instance.discover_auth_server_metadata = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"authorization_endpoint": "https://auth.sentry.io/authorize",
|
||||||
|
"token_endpoint": "https://auth.sentry.io/token",
|
||||||
|
"registration_endpoint": "https://auth.sentry.io/register",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_register.return_value = {
|
||||||
|
"client_id": "registered-client-id",
|
||||||
|
"client_secret": "registered-secret",
|
||||||
|
}
|
||||||
|
mock_cm.store.store_state_token = AsyncMock(
|
||||||
|
return_value=("state-token-123", "code-challenge-abc")
|
||||||
|
)
|
||||||
|
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/oauth/login",
|
||||||
|
json={"server_url": "https://mcp.sentry.dev/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "login_url" in data
|
||||||
|
assert data["state_token"] == "state-token-123"
|
||||||
|
assert "auth.sentry.io/authorize" in data["login_url"]
|
||||||
|
assert "registered-client-id" in data["login_url"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_oauth_login_no_oauth_support(self, client):
|
||||||
|
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
||||||
|
instance = MockClient.return_value
|
||||||
|
instance.discover_auth = AsyncMock(return_value=None)
|
||||||
|
instance.discover_auth_server_metadata = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/oauth/login",
|
||||||
|
json={"server_url": "https://simple-server.example.com/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "does not advertise OAuth" in response.json()["detail"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_oauth_login_fallback_to_public_client(self, client):
|
||||||
|
"""When DCR is unavailable, falls back to default public client ID."""
|
||||||
|
with (
|
||||||
|
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||||
|
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||||
|
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
||||||
|
):
|
||||||
|
instance = MockClient.return_value
|
||||||
|
instance.discover_auth = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"authorization_servers": ["https://auth.example.com"],
|
||||||
|
"resource": "https://mcp.example.com/mcp",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
instance.discover_auth_server_metadata = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"authorization_endpoint": "https://auth.example.com/authorize",
|
||||||
|
"token_endpoint": "https://auth.example.com/token",
|
||||||
|
# No registration_endpoint
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_cm.store.store_state_token = AsyncMock(
|
||||||
|
return_value=("state-abc", "challenge-xyz")
|
||||||
|
)
|
||||||
|
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/oauth/login",
|
||||||
|
json={"server_url": "https://mcp.example.com/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "autogpt-platform" in data["login_url"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuthCallback:
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_oauth_callback_success(self, client):
|
||||||
|
mock_creds = OAuth2Credentials(
|
||||||
|
provider="mcp",
|
||||||
|
title=None,
|
||||||
|
access_token=SecretStr("access-token-xyz"),
|
||||||
|
refresh_token=None,
|
||||||
|
access_token_expires_at=None,
|
||||||
|
refresh_token_expires_at=None,
|
||||||
|
scopes=[],
|
||||||
|
metadata={
|
||||||
|
"mcp_token_url": "https://auth.sentry.io/token",
|
||||||
|
"mcp_resource_url": "https://mcp.sentry.dev/mcp",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||||
|
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
||||||
|
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
|
||||||
|
):
|
||||||
|
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
||||||
|
|
||||||
|
# Mock state verification
|
||||||
|
mock_state = AsyncMock()
|
||||||
|
mock_state.state_metadata = {
|
||||||
|
"authorize_url": "https://auth.sentry.io/authorize",
|
||||||
|
"token_url": "https://auth.sentry.io/token",
|
||||||
|
"client_id": "test-client-id",
|
||||||
|
"client_secret": "test-secret",
|
||||||
|
"server_url": "https://mcp.sentry.dev/mcp",
|
||||||
|
}
|
||||||
|
mock_state.scopes = ["openid"]
|
||||||
|
mock_state.code_verifier = "verifier-123"
|
||||||
|
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
|
||||||
|
mock_cm.create = AsyncMock()
|
||||||
|
|
||||||
|
handler_instance = MockHandler.return_value
|
||||||
|
handler_instance.exchange_code_for_tokens = AsyncMock(
|
||||||
|
return_value=mock_creds
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock old credential cleanup
|
||||||
|
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/oauth/callback",
|
||||||
|
json={"code": "auth-code-abc", "state_token": "state-token-123"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "id" in data
|
||||||
|
assert data["provider"] == "mcp"
|
||||||
|
assert data["type"] == "oauth2"
|
||||||
|
mock_cm.create.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_oauth_callback_invalid_state(self, client):
|
||||||
|
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
|
||||||
|
mock_cm.store.verify_state_token = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/oauth/callback",
|
||||||
|
json={"code": "auth-code", "state_token": "bad-state"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "Invalid or expired" in response.json()["detail"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_oauth_callback_token_exchange_fails(self, client):
|
||||||
|
with (
|
||||||
|
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||||
|
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
||||||
|
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
|
||||||
|
):
|
||||||
|
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
||||||
|
mock_state = AsyncMock()
|
||||||
|
mock_state.state_metadata = {
|
||||||
|
"authorize_url": "https://auth.example.com/authorize",
|
||||||
|
"token_url": "https://auth.example.com/token",
|
||||||
|
"client_id": "cid",
|
||||||
|
"server_url": "https://mcp.example.com/mcp",
|
||||||
|
}
|
||||||
|
mock_state.scopes = []
|
||||||
|
mock_state.code_verifier = "v"
|
||||||
|
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
|
||||||
|
|
||||||
|
handler_instance = MockHandler.return_value
|
||||||
|
handler_instance.exchange_code_for_tokens = AsyncMock(
|
||||||
|
side_effect=RuntimeError("Token exchange failed")
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/oauth/callback",
|
||||||
|
json={"code": "bad-code", "state_token": "state"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "token exchange failed" in response.json()["detail"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
class TestStoreToken:
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_store_token_success(self, client):
|
||||||
|
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
|
||||||
|
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||||
|
mock_cm.create = AsyncMock()
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/token",
|
||||||
|
json={
|
||||||
|
"server_url": "https://mcp.example.com/mcp",
|
||||||
|
"token": "my-api-key-123",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["provider"] == "mcp"
|
||||||
|
assert data["type"] == "oauth2"
|
||||||
|
assert data["host"] == "mcp.example.com"
|
||||||
|
mock_cm.create.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_store_token_blank_rejected(self, client):
|
||||||
|
"""Blank token string (after stripping) should return 422."""
|
||||||
|
response = await client.post(
|
||||||
|
"/token",
|
||||||
|
json={
|
||||||
|
"server_url": "https://mcp.example.com/mcp",
|
||||||
|
"token": " ",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Pydantic min_length=1 catches the whitespace-only token
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_store_token_replaces_old_credential(self, client):
|
||||||
|
old_cred = OAuth2Credentials(
|
||||||
|
provider="mcp",
|
||||||
|
title="MCP: mcp.example.com",
|
||||||
|
access_token=SecretStr("old-token"),
|
||||||
|
scopes=[],
|
||||||
|
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
|
||||||
|
)
|
||||||
|
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
|
||||||
|
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[old_cred])
|
||||||
|
mock_cm.create = AsyncMock()
|
||||||
|
mock_cm.store.delete_creds_by_id = AsyncMock()
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/token",
|
||||||
|
json={
|
||||||
|
"server_url": "https://mcp.example.com/mcp",
|
||||||
|
"token": "new-token",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
mock_cm.store.delete_creds_by_id.assert_called_once_with(
|
||||||
|
"test-user-id", old_cred.id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSRFValidation:
|
||||||
|
"""Verify that validate_url_host is enforced on all endpoints."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_discover_tools_ssrf_blocked(self, client):
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.mcp.routes.validate_url_host",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=ValueError("blocked loopback"),
|
||||||
|
):
|
||||||
|
response = await client.post(
|
||||||
|
"/discover-tools",
|
||||||
|
json={"server_url": "http://localhost/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "blocked loopback" in response.json()["detail"].lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_oauth_login_ssrf_blocked(self, client):
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.mcp.routes.validate_url_host",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=ValueError("blocked private IP"),
|
||||||
|
):
|
||||||
|
response = await client.post(
|
||||||
|
"/oauth/login",
|
||||||
|
json={"server_url": "http://10.0.0.1/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "blocked private ip" in response.json()["detail"].lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_store_token_ssrf_blocked(self, client):
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.mcp.routes.validate_url_host",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=ValueError("blocked loopback"),
|
||||||
|
):
|
||||||
|
response = await client.post(
|
||||||
|
"/token",
|
||||||
|
json={
|
||||||
|
"server_url": "http://127.0.0.1/mcp",
|
||||||
|
"token": "some-token",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "blocked loopback" in response.json()["detail"].lower()
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
from typing import Literal
|
|
||||||
|
|
||||||
from backend.util.cache import cached
|
from backend.util.cache import cached
|
||||||
|
|
||||||
from . import db as store_db
|
from . import db as store_db
|
||||||
@@ -23,7 +21,7 @@ def clear_all_caches():
|
|||||||
async def _get_cached_store_agents(
|
async def _get_cached_store_agents(
|
||||||
featured: bool,
|
featured: bool,
|
||||||
creator: str | None,
|
creator: str | None,
|
||||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None,
|
sorted_by: store_db.StoreAgentsSortOptions | None,
|
||||||
search_query: str | None,
|
search_query: str | None,
|
||||||
category: str | None,
|
category: str | None,
|
||||||
page: int,
|
page: int,
|
||||||
@@ -57,7 +55,7 @@ async def _get_cached_agent_details(
|
|||||||
async def _get_cached_store_creators(
|
async def _get_cached_store_creators(
|
||||||
featured: bool,
|
featured: bool,
|
||||||
search_query: str | None,
|
search_query: str | None,
|
||||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None,
|
sorted_by: store_db.StoreCreatorsSortOptions | None,
|
||||||
page: int,
|
page: int,
|
||||||
page_size: int,
|
page_size: int,
|
||||||
):
|
):
|
||||||
@@ -75,4 +73,4 @@ async def _get_cached_store_creators(
|
|||||||
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
|
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
|
||||||
async def _get_cached_creator_details(username: str):
|
async def _get_cached_creator_details(username: str):
|
||||||
"""Cached helper to get creator details."""
|
"""Cached helper to get creator details."""
|
||||||
return await store_db.get_store_creator_details(username=username.lower())
|
return await store_db.get_store_creator(username=username.lower())
|
||||||
|
|||||||
@@ -9,15 +9,26 @@ import logging
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any, get_args, get_origin
|
||||||
|
|
||||||
from prisma.enums import ContentType
|
from prisma.enums import ContentType
|
||||||
|
|
||||||
|
from backend.blocks.llm import LlmModel
|
||||||
from backend.data.db import query_raw_with_schema
|
from backend.data.db import query_raw_with_schema
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _contains_type(annotation: Any, target: type) -> bool:
|
||||||
|
"""Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
|
||||||
|
if annotation is target:
|
||||||
|
return True
|
||||||
|
origin = get_origin(annotation)
|
||||||
|
if origin is None:
|
||||||
|
return False
|
||||||
|
return any(_contains_type(arg, target) for arg in get_args(annotation))
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ContentItem:
|
class ContentItem:
|
||||||
"""Represents a piece of content to be embedded."""
|
"""Represents a piece of content to be embedded."""
|
||||||
@@ -188,45 +199,51 @@ class BlockHandler(ContentHandler):
|
|||||||
try:
|
try:
|
||||||
block_instance = block_cls()
|
block_instance = block_cls()
|
||||||
|
|
||||||
# Skip disabled blocks - they shouldn't be indexed
|
|
||||||
if block_instance.disabled:
|
if block_instance.disabled:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Build searchable text from block metadata
|
# Build searchable text from block metadata
|
||||||
parts = []
|
parts = []
|
||||||
if hasattr(block_instance, "name") and block_instance.name:
|
if block_instance.name:
|
||||||
parts.append(block_instance.name)
|
parts.append(block_instance.name)
|
||||||
if (
|
if block_instance.description:
|
||||||
hasattr(block_instance, "description")
|
|
||||||
and block_instance.description
|
|
||||||
):
|
|
||||||
parts.append(block_instance.description)
|
parts.append(block_instance.description)
|
||||||
if hasattr(block_instance, "categories") and block_instance.categories:
|
if block_instance.categories:
|
||||||
# Convert BlockCategory enum to strings
|
|
||||||
parts.append(
|
parts.append(
|
||||||
" ".join(str(cat.value) for cat in block_instance.categories)
|
" ".join(str(cat.value) for cat in block_instance.categories)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add input/output schema info
|
# Add input schema field descriptions
|
||||||
if hasattr(block_instance, "input_schema"):
|
block_input_fields = block_instance.input_schema.model_fields
|
||||||
schema = block_instance.input_schema
|
parts += [
|
||||||
if hasattr(schema, "model_json_schema"):
|
f"{field_name}: {field_info.description}"
|
||||||
schema_dict = schema.model_json_schema()
|
for field_name, field_info in block_input_fields.items()
|
||||||
if "properties" in schema_dict:
|
if field_info.description
|
||||||
for prop_name, prop_info in schema_dict[
|
]
|
||||||
"properties"
|
|
||||||
].items():
|
|
||||||
if "description" in prop_info:
|
|
||||||
parts.append(
|
|
||||||
f"{prop_name}: {prop_info['description']}"
|
|
||||||
)
|
|
||||||
|
|
||||||
searchable_text = " ".join(parts)
|
searchable_text = " ".join(parts)
|
||||||
|
|
||||||
# Convert categories set of enums to list of strings for JSON serialization
|
|
||||||
categories = getattr(block_instance, "categories", set())
|
|
||||||
categories_list = (
|
categories_list = (
|
||||||
[cat.value for cat in categories] if categories else []
|
[cat.value for cat in block_instance.categories]
|
||||||
|
if block_instance.categories
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract provider names from credentials fields
|
||||||
|
credentials_info = (
|
||||||
|
block_instance.input_schema.get_credentials_fields_info()
|
||||||
|
)
|
||||||
|
is_integration = len(credentials_info) > 0
|
||||||
|
provider_names = [
|
||||||
|
provider.value.lower()
|
||||||
|
for info in credentials_info.values()
|
||||||
|
for provider in info.provider
|
||||||
|
]
|
||||||
|
|
||||||
|
# Check if block has LlmModel field in input schema
|
||||||
|
has_llm_model_field = any(
|
||||||
|
_contains_type(field.annotation, LlmModel)
|
||||||
|
for field in block_instance.input_schema.model_fields.values()
|
||||||
)
|
)
|
||||||
|
|
||||||
items.append(
|
items.append(
|
||||||
@@ -235,8 +252,11 @@ class BlockHandler(ContentHandler):
|
|||||||
content_type=ContentType.BLOCK,
|
content_type=ContentType.BLOCK,
|
||||||
searchable_text=searchable_text,
|
searchable_text=searchable_text,
|
||||||
metadata={
|
metadata={
|
||||||
"name": getattr(block_instance, "name", ""),
|
"name": block_instance.name,
|
||||||
"categories": categories_list,
|
"categories": categories_list,
|
||||||
|
"providers": provider_names,
|
||||||
|
"has_llm_model_field": has_llm_model_field,
|
||||||
|
"is_integration": is_integration,
|
||||||
},
|
},
|
||||||
user_id=None, # Blocks are public
|
user_id=None, # Blocks are public
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -82,9 +82,10 @@ async def test_block_handler_get_missing_items(mocker):
|
|||||||
mock_block_instance.description = "Performs calculations"
|
mock_block_instance.description = "Performs calculations"
|
||||||
mock_block_instance.categories = [MagicMock(value="MATH")]
|
mock_block_instance.categories = [MagicMock(value="MATH")]
|
||||||
mock_block_instance.disabled = False
|
mock_block_instance.disabled = False
|
||||||
mock_block_instance.input_schema.model_json_schema.return_value = {
|
mock_field = MagicMock()
|
||||||
"properties": {"expression": {"description": "Math expression to evaluate"}}
|
mock_field.description = "Math expression to evaluate"
|
||||||
}
|
mock_block_instance.input_schema.model_fields = {"expression": mock_field}
|
||||||
|
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||||
mock_block_class.return_value = mock_block_instance
|
mock_block_class.return_value = mock_block_instance
|
||||||
|
|
||||||
mock_blocks = {"block-uuid-1": mock_block_class}
|
mock_blocks = {"block-uuid-1": mock_block_class}
|
||||||
@@ -309,19 +310,19 @@ async def test_content_handlers_registry():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_block_handler_handles_missing_attributes():
|
async def test_block_handler_handles_empty_attributes():
|
||||||
"""Test BlockHandler gracefully handles blocks with missing attributes."""
|
"""Test BlockHandler handles blocks with empty/falsy attribute values."""
|
||||||
handler = BlockHandler()
|
handler = BlockHandler()
|
||||||
|
|
||||||
# Mock block with minimal attributes
|
# Mock block with empty values (all attributes exist but are falsy)
|
||||||
mock_block_class = MagicMock()
|
mock_block_class = MagicMock()
|
||||||
mock_block_instance = MagicMock()
|
mock_block_instance = MagicMock()
|
||||||
mock_block_instance.name = "Minimal Block"
|
mock_block_instance.name = "Minimal Block"
|
||||||
mock_block_instance.disabled = False
|
mock_block_instance.disabled = False
|
||||||
# No description, categories, or schema
|
mock_block_instance.description = ""
|
||||||
del mock_block_instance.description
|
mock_block_instance.categories = set()
|
||||||
del mock_block_instance.categories
|
mock_block_instance.input_schema.model_fields = {}
|
||||||
del mock_block_instance.input_schema
|
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||||
mock_block_class.return_value = mock_block_instance
|
mock_block_class.return_value = mock_block_instance
|
||||||
|
|
||||||
mock_blocks = {"block-minimal": mock_block_class}
|
mock_blocks = {"block-minimal": mock_block_class}
|
||||||
@@ -352,6 +353,8 @@ async def test_block_handler_skips_failed_blocks():
|
|||||||
good_instance.description = "Works fine"
|
good_instance.description = "Works fine"
|
||||||
good_instance.categories = []
|
good_instance.categories = []
|
||||||
good_instance.disabled = False
|
good_instance.disabled = False
|
||||||
|
good_instance.input_schema.model_fields = {}
|
||||||
|
good_instance.input_schema.get_credentials_fields_info.return_value = {}
|
||||||
good_block.return_value = good_instance
|
good_block.return_value = good_instance
|
||||||
|
|
||||||
bad_block = MagicMock()
|
bad_block = MagicMock()
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -26,7 +26,7 @@ async def test_get_store_agents(mocker):
|
|||||||
mock_agents = [
|
mock_agents = [
|
||||||
prisma.models.StoreAgent(
|
prisma.models.StoreAgent(
|
||||||
listing_id="test-id",
|
listing_id="test-id",
|
||||||
storeListingVersionId="version123",
|
listing_version_id="version123",
|
||||||
slug="test-agent",
|
slug="test-agent",
|
||||||
agent_name="Test Agent",
|
agent_name="Test Agent",
|
||||||
agent_video=None,
|
agent_video=None,
|
||||||
@@ -40,11 +40,11 @@ async def test_get_store_agents(mocker):
|
|||||||
runs=10,
|
runs=10,
|
||||||
rating=4.5,
|
rating=4.5,
|
||||||
versions=["1.0"],
|
versions=["1.0"],
|
||||||
agentGraphVersions=["1"],
|
graph_id="test-graph-id",
|
||||||
agentGraphId="test-graph-id",
|
graph_versions=["1"],
|
||||||
updated_at=datetime.now(),
|
updated_at=datetime.now(),
|
||||||
is_available=False,
|
is_available=False,
|
||||||
useForOnboarding=False,
|
use_for_onboarding=False,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -68,10 +68,10 @@ async def test_get_store_agents(mocker):
|
|||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_get_store_agent_details(mocker):
|
async def test_get_store_agent_details(mocker):
|
||||||
# Mock data
|
# Mock data - StoreAgent view already contains the active version data
|
||||||
mock_agent = prisma.models.StoreAgent(
|
mock_agent = prisma.models.StoreAgent(
|
||||||
listing_id="test-id",
|
listing_id="test-id",
|
||||||
storeListingVersionId="version123",
|
listing_version_id="version123",
|
||||||
slug="test-agent",
|
slug="test-agent",
|
||||||
agent_name="Test Agent",
|
agent_name="Test Agent",
|
||||||
agent_video="video.mp4",
|
agent_video="video.mp4",
|
||||||
@@ -85,102 +85,38 @@ async def test_get_store_agent_details(mocker):
|
|||||||
runs=10,
|
runs=10,
|
||||||
rating=4.5,
|
rating=4.5,
|
||||||
versions=["1.0"],
|
versions=["1.0"],
|
||||||
agentGraphVersions=["1"],
|
graph_id="test-graph-id",
|
||||||
agentGraphId="test-graph-id",
|
graph_versions=["1"],
|
||||||
updated_at=datetime.now(),
|
|
||||||
is_available=False,
|
|
||||||
useForOnboarding=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock active version agent (what we want to return for active version)
|
|
||||||
mock_active_agent = prisma.models.StoreAgent(
|
|
||||||
listing_id="test-id",
|
|
||||||
storeListingVersionId="active-version-id",
|
|
||||||
slug="test-agent",
|
|
||||||
agent_name="Test Agent Active",
|
|
||||||
agent_video="active_video.mp4",
|
|
||||||
agent_image=["active_image.jpg"],
|
|
||||||
featured=False,
|
|
||||||
creator_username="creator",
|
|
||||||
creator_avatar="avatar.jpg",
|
|
||||||
sub_heading="Test heading active",
|
|
||||||
description="Test description active",
|
|
||||||
categories=["test"],
|
|
||||||
runs=15,
|
|
||||||
rating=4.8,
|
|
||||||
versions=["1.0", "2.0"],
|
|
||||||
agentGraphVersions=["1", "2"],
|
|
||||||
agentGraphId="test-graph-id-active",
|
|
||||||
updated_at=datetime.now(),
|
updated_at=datetime.now(),
|
||||||
is_available=True,
|
is_available=True,
|
||||||
useForOnboarding=False,
|
use_for_onboarding=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a mock StoreListing result
|
# Mock StoreAgent prisma call
|
||||||
mock_store_listing = mocker.MagicMock()
|
|
||||||
mock_store_listing.activeVersionId = "active-version-id"
|
|
||||||
mock_store_listing.hasApprovedVersion = True
|
|
||||||
mock_store_listing.ActiveVersion = mocker.MagicMock()
|
|
||||||
mock_store_listing.ActiveVersion.recommendedScheduleCron = None
|
|
||||||
|
|
||||||
# Mock StoreAgent prisma call - need to handle multiple calls
|
|
||||||
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
|
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
|
||||||
|
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||||
# Set up side_effect to return different results for different calls
|
|
||||||
def mock_find_first_side_effect(*args, **kwargs):
|
|
||||||
where_clause = kwargs.get("where", {})
|
|
||||||
if "storeListingVersionId" in where_clause:
|
|
||||||
# Second call for active version
|
|
||||||
return mock_active_agent
|
|
||||||
else:
|
|
||||||
# First call for initial lookup
|
|
||||||
return mock_agent
|
|
||||||
|
|
||||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(
|
|
||||||
side_effect=mock_find_first_side_effect
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock Profile prisma call
|
|
||||||
mock_profile = mocker.MagicMock()
|
|
||||||
mock_profile.userId = "user-id-123"
|
|
||||||
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
|
|
||||||
mock_profile_db.return_value.find_first = mocker.AsyncMock(
|
|
||||||
return_value=mock_profile
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock StoreListing prisma call
|
|
||||||
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
|
|
||||||
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
|
|
||||||
return_value=mock_store_listing
|
|
||||||
)
|
|
||||||
|
|
||||||
# Call function
|
# Call function
|
||||||
result = await db.get_store_agent_details("creator", "test-agent")
|
result = await db.get_store_agent_details("creator", "test-agent")
|
||||||
|
|
||||||
# Verify results - should use active version data
|
# Verify results - constructed from the StoreAgent view
|
||||||
assert result.slug == "test-agent"
|
assert result.slug == "test-agent"
|
||||||
assert result.agent_name == "Test Agent Active" # From active version
|
assert result.agent_name == "Test Agent"
|
||||||
assert result.active_version_id == "active-version-id"
|
assert result.active_version_id == "version123"
|
||||||
assert result.has_approved_version is True
|
assert result.has_approved_version is True
|
||||||
assert (
|
assert result.store_listing_version_id == "version123"
|
||||||
result.store_listing_version_id == "active-version-id"
|
assert result.graph_id == "test-graph-id"
|
||||||
) # Should be active version ID
|
assert result.runs == 10
|
||||||
|
assert result.rating == 4.5
|
||||||
|
|
||||||
# Verify mocks called correctly - now expecting 2 calls
|
# Verify single StoreAgent lookup
|
||||||
assert mock_store_agent.return_value.find_first.call_count == 2
|
mock_store_agent.return_value.find_first.assert_called_once_with(
|
||||||
|
|
||||||
# Check the specific calls
|
|
||||||
calls = mock_store_agent.return_value.find_first.call_args_list
|
|
||||||
assert calls[0] == mocker.call(
|
|
||||||
where={"creator_username": "creator", "slug": "test-agent"}
|
where={"creator_username": "creator", "slug": "test-agent"}
|
||||||
)
|
)
|
||||||
assert calls[1] == mocker.call(where={"storeListingVersionId": "active-version-id"})
|
|
||||||
|
|
||||||
mock_store_listing_db.return_value.find_first.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_get_store_creator_details(mocker):
|
async def test_get_store_creator(mocker):
|
||||||
# Mock data
|
# Mock data
|
||||||
mock_creator_data = prisma.models.Creator(
|
mock_creator_data = prisma.models.Creator(
|
||||||
name="Test Creator",
|
name="Test Creator",
|
||||||
@@ -202,7 +138,7 @@ async def test_get_store_creator_details(mocker):
|
|||||||
mock_creator.return_value.find_unique.return_value = mock_creator_data
|
mock_creator.return_value.find_unique.return_value = mock_creator_data
|
||||||
|
|
||||||
# Call function
|
# Call function
|
||||||
result = await db.get_store_creator_details("creator")
|
result = await db.get_store_creator("creator")
|
||||||
|
|
||||||
# Verify results
|
# Verify results
|
||||||
assert result.username == "creator"
|
assert result.username == "creator"
|
||||||
@@ -218,61 +154,110 @@ async def test_get_store_creator_details(mocker):
|
|||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_create_store_submission(mocker):
|
async def test_create_store_submission(mocker):
|
||||||
# Mock data
|
now = datetime.now()
|
||||||
|
|
||||||
|
# Mock agent graph (with no pending submissions) and user with profile
|
||||||
|
mock_profile = prisma.models.Profile(
|
||||||
|
id="profile-id",
|
||||||
|
userId="user-id",
|
||||||
|
name="Test User",
|
||||||
|
username="testuser",
|
||||||
|
description="Test",
|
||||||
|
isFeatured=False,
|
||||||
|
links=[],
|
||||||
|
createdAt=now,
|
||||||
|
updatedAt=now,
|
||||||
|
)
|
||||||
|
mock_user = prisma.models.User(
|
||||||
|
id="user-id",
|
||||||
|
email="test@example.com",
|
||||||
|
createdAt=now,
|
||||||
|
updatedAt=now,
|
||||||
|
Profile=[mock_profile],
|
||||||
|
emailVerified=True,
|
||||||
|
metadata="{}", # type: ignore[reportArgumentType]
|
||||||
|
integrations="",
|
||||||
|
maxEmailsPerDay=1,
|
||||||
|
notifyOnAgentRun=True,
|
||||||
|
notifyOnZeroBalance=True,
|
||||||
|
notifyOnLowBalance=True,
|
||||||
|
notifyOnBlockExecutionFailed=True,
|
||||||
|
notifyOnContinuousAgentError=True,
|
||||||
|
notifyOnDailySummary=True,
|
||||||
|
notifyOnWeeklySummary=True,
|
||||||
|
notifyOnMonthlySummary=True,
|
||||||
|
notifyOnAgentApproved=True,
|
||||||
|
notifyOnAgentRejected=True,
|
||||||
|
timezone="Europe/Delft",
|
||||||
|
)
|
||||||
mock_agent = prisma.models.AgentGraph(
|
mock_agent = prisma.models.AgentGraph(
|
||||||
id="agent-id",
|
id="agent-id",
|
||||||
version=1,
|
version=1,
|
||||||
userId="user-id",
|
userId="user-id",
|
||||||
createdAt=datetime.now(),
|
createdAt=now,
|
||||||
isActive=True,
|
isActive=True,
|
||||||
|
StoreListingVersions=[],
|
||||||
|
User=mock_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_listing = prisma.models.StoreListing(
|
# Mock the created StoreListingVersion (returned by create)
|
||||||
|
mock_store_listing_obj = prisma.models.StoreListing(
|
||||||
id="listing-id",
|
id="listing-id",
|
||||||
createdAt=datetime.now(),
|
createdAt=now,
|
||||||
updatedAt=datetime.now(),
|
updatedAt=now,
|
||||||
isDeleted=False,
|
isDeleted=False,
|
||||||
hasApprovedVersion=False,
|
hasApprovedVersion=False,
|
||||||
slug="test-agent",
|
slug="test-agent",
|
||||||
agentGraphId="agent-id",
|
agentGraphId="agent-id",
|
||||||
agentGraphVersion=1,
|
|
||||||
owningUserId="user-id",
|
owningUserId="user-id",
|
||||||
Versions=[
|
|
||||||
prisma.models.StoreListingVersion(
|
|
||||||
id="version-id",
|
|
||||||
agentGraphId="agent-id",
|
|
||||||
agentGraphVersion=1,
|
|
||||||
name="Test Agent",
|
|
||||||
description="Test description",
|
|
||||||
createdAt=datetime.now(),
|
|
||||||
updatedAt=datetime.now(),
|
|
||||||
subHeading="Test heading",
|
|
||||||
imageUrls=["image.jpg"],
|
|
||||||
categories=["test"],
|
|
||||||
isFeatured=False,
|
|
||||||
isDeleted=False,
|
|
||||||
version=1,
|
|
||||||
storeListingId="listing-id",
|
|
||||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
|
||||||
isAvailable=True,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
useForOnboarding=False,
|
useForOnboarding=False,
|
||||||
)
|
)
|
||||||
|
mock_version = prisma.models.StoreListingVersion(
|
||||||
|
id="version-id",
|
||||||
|
agentGraphId="agent-id",
|
||||||
|
agentGraphVersion=1,
|
||||||
|
name="Test Agent",
|
||||||
|
description="Test description",
|
||||||
|
createdAt=now,
|
||||||
|
updatedAt=now,
|
||||||
|
subHeading="",
|
||||||
|
imageUrls=[],
|
||||||
|
categories=[],
|
||||||
|
isFeatured=False,
|
||||||
|
isDeleted=False,
|
||||||
|
version=1,
|
||||||
|
storeListingId="listing-id",
|
||||||
|
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||||
|
isAvailable=True,
|
||||||
|
submittedAt=now,
|
||||||
|
StoreListing=mock_store_listing_obj,
|
||||||
|
)
|
||||||
|
|
||||||
# Mock prisma calls
|
# Mock prisma calls
|
||||||
mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
|
mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
|
||||||
mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||||
|
|
||||||
mock_store_listing = mocker.patch("prisma.models.StoreListing.prisma")
|
# Mock transaction context manager
|
||||||
mock_store_listing.return_value.find_first = mocker.AsyncMock(return_value=None)
|
mock_tx = mocker.MagicMock()
|
||||||
mock_store_listing.return_value.create = mocker.AsyncMock(return_value=mock_listing)
|
mocker.patch(
|
||||||
|
"backend.api.features.store.db.transaction",
|
||||||
|
return_value=mocker.AsyncMock(
|
||||||
|
__aenter__=mocker.AsyncMock(return_value=mock_tx),
|
||||||
|
__aexit__=mocker.AsyncMock(return_value=False),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_sl = mocker.patch("prisma.models.StoreListing.prisma")
|
||||||
|
mock_sl.return_value.find_unique = mocker.AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
mock_slv = mocker.patch("prisma.models.StoreListingVersion.prisma")
|
||||||
|
mock_slv.return_value.create = mocker.AsyncMock(return_value=mock_version)
|
||||||
|
|
||||||
# Call function
|
# Call function
|
||||||
result = await db.create_store_submission(
|
result = await db.create_store_submission(
|
||||||
user_id="user-id",
|
user_id="user-id",
|
||||||
agent_id="agent-id",
|
graph_id="agent-id",
|
||||||
agent_version=1,
|
graph_version=1,
|
||||||
slug="test-agent",
|
slug="test-agent",
|
||||||
name="Test Agent",
|
name="Test Agent",
|
||||||
description="Test description",
|
description="Test description",
|
||||||
@@ -281,11 +266,11 @@ async def test_create_store_submission(mocker):
|
|||||||
# Verify results
|
# Verify results
|
||||||
assert result.name == "Test Agent"
|
assert result.name == "Test Agent"
|
||||||
assert result.description == "Test description"
|
assert result.description == "Test description"
|
||||||
assert result.store_listing_version_id == "version-id"
|
assert result.listing_version_id == "version-id"
|
||||||
|
|
||||||
# Verify mocks called correctly
|
# Verify mocks called correctly
|
||||||
mock_agent_graph.return_value.find_first.assert_called_once()
|
mock_agent_graph.return_value.find_first.assert_called_once()
|
||||||
mock_store_listing.return_value.create.assert_called_once()
|
mock_slv.return_value.create.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
@@ -318,7 +303,6 @@ async def test_update_profile(mocker):
|
|||||||
description="Test description",
|
description="Test description",
|
||||||
links=["link1"],
|
links=["link1"],
|
||||||
avatar_url="avatar.jpg",
|
avatar_url="avatar.jpg",
|
||||||
is_featured=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call function
|
# Call function
|
||||||
@@ -389,7 +373,7 @@ async def test_get_store_agents_with_search_and_filters_parameterized():
|
|||||||
creators=["creator1'; DROP TABLE Users; --", "creator2"],
|
creators=["creator1'; DROP TABLE Users; --", "creator2"],
|
||||||
category="AI'; DELETE FROM StoreAgent; --",
|
category="AI'; DELETE FROM StoreAgent; --",
|
||||||
featured=True,
|
featured=True,
|
||||||
sorted_by="rating",
|
sorted_by=db.StoreAgentsSortOptions.RATING,
|
||||||
page=1,
|
page=1,
|
||||||
page_size=20,
|
page_size=20,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -57,12 +57,6 @@ class StoreError(ValueError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AgentNotFoundError(NotFoundError):
|
|
||||||
"""Raised when an agent is not found"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class CreatorNotFoundError(NotFoundError):
|
class CreatorNotFoundError(NotFoundError):
|
||||||
"""Raised when a creator is not found"""
|
"""Raised when a creator is not found"""
|
||||||
|
|
||||||
|
|||||||
@@ -568,7 +568,7 @@ async def hybrid_search(
|
|||||||
SELECT uce."contentId" as "storeListingVersionId"
|
SELECT uce."contentId" as "storeListingVersionId"
|
||||||
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||||
ON uce."contentId" = sa."storeListingVersionId"
|
ON uce."contentId" = sa.listing_version_id
|
||||||
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||||
AND uce."userId" IS NULL
|
AND uce."userId" IS NULL
|
||||||
AND uce.search @@ plainto_tsquery('english', {query_param})
|
AND uce.search @@ plainto_tsquery('english', {query_param})
|
||||||
@@ -582,7 +582,7 @@ async def hybrid_search(
|
|||||||
SELECT uce."contentId", uce.embedding
|
SELECT uce."contentId", uce.embedding
|
||||||
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||||
ON uce."contentId" = sa."storeListingVersionId"
|
ON uce."contentId" = sa.listing_version_id
|
||||||
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||||
AND uce."userId" IS NULL
|
AND uce."userId" IS NULL
|
||||||
AND {where_clause}
|
AND {where_clause}
|
||||||
@@ -605,7 +605,7 @@ async def hybrid_search(
|
|||||||
sa.featured,
|
sa.featured,
|
||||||
sa.is_available,
|
sa.is_available,
|
||||||
sa.updated_at,
|
sa.updated_at,
|
||||||
sa."agentGraphId",
|
sa.graph_id,
|
||||||
-- Searchable text for BM25 reranking
|
-- Searchable text for BM25 reranking
|
||||||
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
|
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
|
||||||
-- Semantic score
|
-- Semantic score
|
||||||
@@ -627,9 +627,9 @@ async def hybrid_search(
|
|||||||
sa.runs as popularity_raw
|
sa.runs as popularity_raw
|
||||||
FROM candidates c
|
FROM candidates c
|
||||||
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
INNER JOIN {{schema_prefix}}"StoreAgent" sa
|
||||||
ON c."storeListingVersionId" = sa."storeListingVersionId"
|
ON c."storeListingVersionId" = sa.listing_version_id
|
||||||
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
|
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||||
ON sa."storeListingVersionId" = uce."contentId"
|
ON sa.listing_version_id = uce."contentId"
|
||||||
AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||||
),
|
),
|
||||||
max_vals AS (
|
max_vals AS (
|
||||||
@@ -665,7 +665,7 @@ async def hybrid_search(
|
|||||||
featured,
|
featured,
|
||||||
is_available,
|
is_available,
|
||||||
updated_at,
|
updated_at,
|
||||||
"agentGraphId",
|
graph_id,
|
||||||
searchable_text,
|
searchable_text,
|
||||||
semantic_score,
|
semantic_score,
|
||||||
lexical_score,
|
lexical_score,
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
import datetime
|
import datetime
|
||||||
from typing import List
|
from typing import TYPE_CHECKING, List, Self
|
||||||
|
|
||||||
import prisma.enums
|
import prisma.enums
|
||||||
import pydantic
|
import pydantic
|
||||||
|
|
||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import prisma.models
|
||||||
|
|
||||||
|
|
||||||
class ChangelogEntry(pydantic.BaseModel):
|
class ChangelogEntry(pydantic.BaseModel):
|
||||||
version: str
|
version: str
|
||||||
@@ -13,9 +16,9 @@ class ChangelogEntry(pydantic.BaseModel):
|
|||||||
date: datetime.datetime
|
date: datetime.datetime
|
||||||
|
|
||||||
|
|
||||||
class MyAgent(pydantic.BaseModel):
|
class MyUnpublishedAgent(pydantic.BaseModel):
|
||||||
agent_id: str
|
graph_id: str
|
||||||
agent_version: int
|
graph_version: int
|
||||||
agent_name: str
|
agent_name: str
|
||||||
agent_image: str | None = None
|
agent_image: str | None = None
|
||||||
description: str
|
description: str
|
||||||
@@ -23,8 +26,8 @@ class MyAgent(pydantic.BaseModel):
|
|||||||
recommended_schedule_cron: str | None = None
|
recommended_schedule_cron: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class MyAgentsResponse(pydantic.BaseModel):
|
class MyUnpublishedAgentsResponse(pydantic.BaseModel):
|
||||||
agents: list[MyAgent]
|
agents: list[MyUnpublishedAgent]
|
||||||
pagination: Pagination
|
pagination: Pagination
|
||||||
|
|
||||||
|
|
||||||
@@ -40,6 +43,21 @@ class StoreAgent(pydantic.BaseModel):
|
|||||||
rating: float
|
rating: float
|
||||||
agent_graph_id: str
|
agent_graph_id: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgent":
|
||||||
|
return cls(
|
||||||
|
slug=agent.slug,
|
||||||
|
agent_name=agent.agent_name,
|
||||||
|
agent_image=agent.agent_image[0] if agent.agent_image else "",
|
||||||
|
creator=agent.creator_username or "Needs Profile",
|
||||||
|
creator_avatar=agent.creator_avatar or "",
|
||||||
|
sub_heading=agent.sub_heading,
|
||||||
|
description=agent.description,
|
||||||
|
runs=agent.runs,
|
||||||
|
rating=agent.rating,
|
||||||
|
agent_graph_id=agent.graph_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StoreAgentsResponse(pydantic.BaseModel):
|
class StoreAgentsResponse(pydantic.BaseModel):
|
||||||
agents: list[StoreAgent]
|
agents: list[StoreAgent]
|
||||||
@@ -62,81 +80,192 @@ class StoreAgentDetails(pydantic.BaseModel):
|
|||||||
runs: int
|
runs: int
|
||||||
rating: float
|
rating: float
|
||||||
versions: list[str]
|
versions: list[str]
|
||||||
agentGraphVersions: list[str]
|
graph_id: str
|
||||||
agentGraphId: str
|
graph_versions: list[str]
|
||||||
last_updated: datetime.datetime
|
last_updated: datetime.datetime
|
||||||
recommended_schedule_cron: str | None = None
|
recommended_schedule_cron: str | None = None
|
||||||
|
|
||||||
active_version_id: str | None = None
|
active_version_id: str
|
||||||
has_approved_version: bool = False
|
has_approved_version: bool
|
||||||
|
|
||||||
# Optional changelog data when include_changelog=True
|
# Optional changelog data when include_changelog=True
|
||||||
changelog: list[ChangelogEntry] | None = None
|
changelog: list[ChangelogEntry] | None = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
class Creator(pydantic.BaseModel):
|
def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgentDetails":
|
||||||
name: str
|
return cls(
|
||||||
username: str
|
store_listing_version_id=agent.listing_version_id,
|
||||||
description: str
|
slug=agent.slug,
|
||||||
avatar_url: str
|
agent_name=agent.agent_name,
|
||||||
num_agents: int
|
agent_video=agent.agent_video or "",
|
||||||
agent_rating: float
|
agent_output_demo=agent.agent_output_demo or "",
|
||||||
agent_runs: int
|
agent_image=agent.agent_image,
|
||||||
is_featured: bool
|
creator=agent.creator_username or "",
|
||||||
|
creator_avatar=agent.creator_avatar or "",
|
||||||
|
sub_heading=agent.sub_heading,
|
||||||
class CreatorsResponse(pydantic.BaseModel):
|
description=agent.description,
|
||||||
creators: List[Creator]
|
categories=agent.categories,
|
||||||
pagination: Pagination
|
runs=agent.runs,
|
||||||
|
rating=agent.rating,
|
||||||
|
versions=agent.versions,
|
||||||
class CreatorDetails(pydantic.BaseModel):
|
graph_id=agent.graph_id,
|
||||||
name: str
|
graph_versions=agent.graph_versions,
|
||||||
username: str
|
last_updated=agent.updated_at,
|
||||||
description: str
|
recommended_schedule_cron=agent.recommended_schedule_cron,
|
||||||
links: list[str]
|
active_version_id=agent.listing_version_id,
|
||||||
avatar_url: str
|
has_approved_version=True, # StoreAgent view only has approved agents
|
||||||
agent_rating: float
|
)
|
||||||
agent_runs: int
|
|
||||||
top_categories: list[str]
|
|
||||||
|
|
||||||
|
|
||||||
class Profile(pydantic.BaseModel):
|
class Profile(pydantic.BaseModel):
|
||||||
name: str
|
"""Marketplace user profile (only attributes that the user can update)"""
|
||||||
|
|
||||||
username: str
|
username: str
|
||||||
|
name: str
|
||||||
description: str
|
description: str
|
||||||
|
avatar_url: str | None
|
||||||
links: list[str]
|
links: list[str]
|
||||||
avatar_url: str
|
|
||||||
is_featured: bool = False
|
|
||||||
|
class ProfileDetails(Profile):
|
||||||
|
"""Marketplace user profile (including read-only fields)"""
|
||||||
|
|
||||||
|
is_featured: bool
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_db(cls, profile: "prisma.models.Profile") -> "ProfileDetails":
|
||||||
|
return cls(
|
||||||
|
name=profile.name,
|
||||||
|
username=profile.username,
|
||||||
|
avatar_url=profile.avatarUrl,
|
||||||
|
description=profile.description,
|
||||||
|
links=profile.links,
|
||||||
|
is_featured=profile.isFeatured,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CreatorDetails(ProfileDetails):
|
||||||
|
"""Marketplace creator profile details, including aggregated stats"""
|
||||||
|
|
||||||
|
num_agents: int
|
||||||
|
agent_runs: int
|
||||||
|
agent_rating: float
|
||||||
|
top_categories: list[str]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_db(cls, creator: "prisma.models.Creator") -> "CreatorDetails": # type: ignore[override]
|
||||||
|
return cls(
|
||||||
|
name=creator.name,
|
||||||
|
username=creator.username,
|
||||||
|
avatar_url=creator.avatar_url,
|
||||||
|
description=creator.description,
|
||||||
|
links=creator.links,
|
||||||
|
is_featured=creator.is_featured,
|
||||||
|
num_agents=creator.num_agents,
|
||||||
|
agent_runs=creator.agent_runs,
|
||||||
|
agent_rating=creator.agent_rating,
|
||||||
|
top_categories=creator.top_categories,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CreatorsResponse(pydantic.BaseModel):
|
||||||
|
creators: List[CreatorDetails]
|
||||||
|
pagination: Pagination
|
||||||
|
|
||||||
|
|
||||||
class StoreSubmission(pydantic.BaseModel):
|
class StoreSubmission(pydantic.BaseModel):
|
||||||
|
# From StoreListing:
|
||||||
listing_id: str
|
listing_id: str
|
||||||
agent_id: str
|
user_id: str
|
||||||
agent_version: int
|
slug: str
|
||||||
|
|
||||||
|
# From StoreListingVersion:
|
||||||
|
listing_version_id: str
|
||||||
|
listing_version: int
|
||||||
|
graph_id: str
|
||||||
|
graph_version: int
|
||||||
name: str
|
name: str
|
||||||
sub_heading: str
|
sub_heading: str
|
||||||
slug: str
|
|
||||||
description: str
|
description: str
|
||||||
instructions: str | None = None
|
instructions: str | None
|
||||||
|
categories: list[str]
|
||||||
image_urls: list[str]
|
image_urls: list[str]
|
||||||
date_submitted: datetime.datetime
|
video_url: str | None
|
||||||
status: prisma.enums.SubmissionStatus
|
agent_output_demo_url: str | None
|
||||||
runs: int
|
|
||||||
rating: float
|
|
||||||
store_listing_version_id: str | None = None
|
|
||||||
version: int | None = None # Actual version number from the database
|
|
||||||
|
|
||||||
|
submitted_at: datetime.datetime | None
|
||||||
|
changes_summary: str | None
|
||||||
|
status: prisma.enums.SubmissionStatus
|
||||||
|
reviewed_at: datetime.datetime | None = None
|
||||||
reviewer_id: str | None = None
|
reviewer_id: str | None = None
|
||||||
review_comments: str | None = None # External comments visible to creator
|
review_comments: str | None = None # External comments visible to creator
|
||||||
internal_comments: str | None = None # Private notes for admin use only
|
|
||||||
reviewed_at: datetime.datetime | None = None
|
|
||||||
changes_summary: str | None = None
|
|
||||||
|
|
||||||
# Additional fields for editing
|
# Aggregated from AgentGraphExecutions and StoreListingReviews:
|
||||||
video_url: str | None = None
|
run_count: int = 0
|
||||||
agent_output_demo_url: str | None = None
|
review_count: int = 0
|
||||||
categories: list[str] = []
|
review_avg_rating: float = 0.0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
|
||||||
|
"""Construct from the StoreSubmission Prisma view."""
|
||||||
|
return cls(
|
||||||
|
listing_id=_sub.listing_id,
|
||||||
|
user_id=_sub.user_id,
|
||||||
|
slug=_sub.slug,
|
||||||
|
listing_version_id=_sub.listing_version_id,
|
||||||
|
listing_version=_sub.listing_version,
|
||||||
|
graph_id=_sub.graph_id,
|
||||||
|
graph_version=_sub.graph_version,
|
||||||
|
name=_sub.name,
|
||||||
|
sub_heading=_sub.sub_heading,
|
||||||
|
description=_sub.description,
|
||||||
|
instructions=_sub.instructions,
|
||||||
|
categories=_sub.categories,
|
||||||
|
image_urls=_sub.image_urls,
|
||||||
|
video_url=_sub.video_url,
|
||||||
|
agent_output_demo_url=_sub.agent_output_demo_url,
|
||||||
|
submitted_at=_sub.submitted_at,
|
||||||
|
changes_summary=_sub.changes_summary,
|
||||||
|
status=_sub.status,
|
||||||
|
reviewed_at=_sub.reviewed_at,
|
||||||
|
reviewer_id=_sub.reviewer_id,
|
||||||
|
review_comments=_sub.review_comments,
|
||||||
|
run_count=_sub.run_count,
|
||||||
|
review_count=_sub.review_count,
|
||||||
|
review_avg_rating=_sub.review_avg_rating,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
|
||||||
|
"""
|
||||||
|
Construct from the StoreListingVersion Prisma model (with StoreListing included)
|
||||||
|
"""
|
||||||
|
if not (_l := _lv.StoreListing):
|
||||||
|
raise ValueError("StoreListingVersion must have included StoreListing")
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
listing_id=_l.id,
|
||||||
|
user_id=_l.owningUserId,
|
||||||
|
slug=_l.slug,
|
||||||
|
listing_version_id=_lv.id,
|
||||||
|
listing_version=_lv.version,
|
||||||
|
graph_id=_lv.agentGraphId,
|
||||||
|
graph_version=_lv.agentGraphVersion,
|
||||||
|
name=_lv.name,
|
||||||
|
sub_heading=_lv.subHeading,
|
||||||
|
description=_lv.description,
|
||||||
|
instructions=_lv.instructions,
|
||||||
|
categories=_lv.categories,
|
||||||
|
image_urls=_lv.imageUrls,
|
||||||
|
video_url=_lv.videoUrl,
|
||||||
|
agent_output_demo_url=_lv.agentOutputDemoUrl,
|
||||||
|
submitted_at=_lv.submittedAt,
|
||||||
|
changes_summary=_lv.changesSummary,
|
||||||
|
status=_lv.submissionStatus,
|
||||||
|
reviewed_at=_lv.reviewedAt,
|
||||||
|
reviewer_id=_lv.reviewerId,
|
||||||
|
review_comments=_lv.reviewComments,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StoreSubmissionsResponse(pydantic.BaseModel):
|
class StoreSubmissionsResponse(pydantic.BaseModel):
|
||||||
@@ -144,33 +273,12 @@ class StoreSubmissionsResponse(pydantic.BaseModel):
|
|||||||
pagination: Pagination
|
pagination: Pagination
|
||||||
|
|
||||||
|
|
||||||
class StoreListingWithVersions(pydantic.BaseModel):
|
|
||||||
"""A store listing with its version history"""
|
|
||||||
|
|
||||||
listing_id: str
|
|
||||||
slug: str
|
|
||||||
agent_id: str
|
|
||||||
agent_version: int
|
|
||||||
active_version_id: str | None = None
|
|
||||||
has_approved_version: bool = False
|
|
||||||
creator_email: str | None = None
|
|
||||||
latest_version: StoreSubmission | None = None
|
|
||||||
versions: list[StoreSubmission] = []
|
|
||||||
|
|
||||||
|
|
||||||
class StoreListingsWithVersionsResponse(pydantic.BaseModel):
|
|
||||||
"""Response model for listings with version history"""
|
|
||||||
|
|
||||||
listings: list[StoreListingWithVersions]
|
|
||||||
pagination: Pagination
|
|
||||||
|
|
||||||
|
|
||||||
class StoreSubmissionRequest(pydantic.BaseModel):
|
class StoreSubmissionRequest(pydantic.BaseModel):
|
||||||
agent_id: str = pydantic.Field(
|
graph_id: str = pydantic.Field(
|
||||||
..., min_length=1, description="Agent ID cannot be empty"
|
..., min_length=1, description="Graph ID cannot be empty"
|
||||||
)
|
)
|
||||||
agent_version: int = pydantic.Field(
|
graph_version: int = pydantic.Field(
|
||||||
..., gt=0, description="Agent version must be greater than 0"
|
..., gt=0, description="Graph version must be greater than 0"
|
||||||
)
|
)
|
||||||
slug: str
|
slug: str
|
||||||
name: str
|
name: str
|
||||||
@@ -198,12 +306,42 @@ class StoreSubmissionEditRequest(pydantic.BaseModel):
|
|||||||
recommended_schedule_cron: str | None = None
|
recommended_schedule_cron: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class ProfileDetails(pydantic.BaseModel):
|
class StoreSubmissionAdminView(StoreSubmission):
|
||||||
name: str
|
internal_comments: str | None # Private admin notes
|
||||||
username: str
|
|
||||||
description: str
|
@classmethod
|
||||||
links: list[str]
|
def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
|
||||||
avatar_url: str | None = None
|
return cls(
|
||||||
|
**StoreSubmission.from_db(_sub).model_dump(),
|
||||||
|
internal_comments=_sub.internal_comments,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
|
||||||
|
return cls(
|
||||||
|
**StoreSubmission.from_listing_version(_lv).model_dump(),
|
||||||
|
internal_comments=_lv.internalComments,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StoreListingWithVersionsAdminView(pydantic.BaseModel):
|
||||||
|
"""A store listing with its version history"""
|
||||||
|
|
||||||
|
listing_id: str
|
||||||
|
graph_id: str
|
||||||
|
slug: str
|
||||||
|
active_listing_version_id: str | None = None
|
||||||
|
has_approved_version: bool = False
|
||||||
|
creator_email: str | None = None
|
||||||
|
latest_version: StoreSubmissionAdminView | None = None
|
||||||
|
versions: list[StoreSubmissionAdminView] = []
|
||||||
|
|
||||||
|
|
||||||
|
class StoreListingsWithVersionsAdminViewResponse(pydantic.BaseModel):
|
||||||
|
"""Response model for listings with version history"""
|
||||||
|
|
||||||
|
listings: list[StoreListingWithVersionsAdminView]
|
||||||
|
pagination: Pagination
|
||||||
|
|
||||||
|
|
||||||
class StoreReview(pydantic.BaseModel):
|
class StoreReview(pydantic.BaseModel):
|
||||||
|
|||||||
@@ -1,203 +0,0 @@
|
|||||||
import datetime
|
|
||||||
|
|
||||||
import prisma.enums
|
|
||||||
|
|
||||||
from . import model as store_model
|
|
||||||
|
|
||||||
|
|
||||||
def test_pagination():
|
|
||||||
pagination = store_model.Pagination(
|
|
||||||
total_items=100, total_pages=5, current_page=2, page_size=20
|
|
||||||
)
|
|
||||||
assert pagination.total_items == 100
|
|
||||||
assert pagination.total_pages == 5
|
|
||||||
assert pagination.current_page == 2
|
|
||||||
assert pagination.page_size == 20
|
|
||||||
|
|
||||||
|
|
||||||
def test_store_agent():
|
|
||||||
agent = store_model.StoreAgent(
|
|
||||||
slug="test-agent",
|
|
||||||
agent_name="Test Agent",
|
|
||||||
agent_image="test.jpg",
|
|
||||||
creator="creator1",
|
|
||||||
creator_avatar="avatar.jpg",
|
|
||||||
sub_heading="Test subheading",
|
|
||||||
description="Test description",
|
|
||||||
runs=50,
|
|
||||||
rating=4.5,
|
|
||||||
agent_graph_id="test-graph-id",
|
|
||||||
)
|
|
||||||
assert agent.slug == "test-agent"
|
|
||||||
assert agent.agent_name == "Test Agent"
|
|
||||||
assert agent.runs == 50
|
|
||||||
assert agent.rating == 4.5
|
|
||||||
assert agent.agent_graph_id == "test-graph-id"
|
|
||||||
|
|
||||||
|
|
||||||
def test_store_agents_response():
|
|
||||||
response = store_model.StoreAgentsResponse(
|
|
||||||
agents=[
|
|
||||||
store_model.StoreAgent(
|
|
||||||
slug="test-agent",
|
|
||||||
agent_name="Test Agent",
|
|
||||||
agent_image="test.jpg",
|
|
||||||
creator="creator1",
|
|
||||||
creator_avatar="avatar.jpg",
|
|
||||||
sub_heading="Test subheading",
|
|
||||||
description="Test description",
|
|
||||||
runs=50,
|
|
||||||
rating=4.5,
|
|
||||||
agent_graph_id="test-graph-id",
|
|
||||||
)
|
|
||||||
],
|
|
||||||
pagination=store_model.Pagination(
|
|
||||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
|
||||||
),
|
|
||||||
)
|
|
||||||
assert len(response.agents) == 1
|
|
||||||
assert response.pagination.total_items == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_store_agent_details():
|
|
||||||
details = store_model.StoreAgentDetails(
|
|
||||||
store_listing_version_id="version123",
|
|
||||||
slug="test-agent",
|
|
||||||
agent_name="Test Agent",
|
|
||||||
agent_video="video.mp4",
|
|
||||||
agent_output_demo="demo.mp4",
|
|
||||||
agent_image=["image1.jpg", "image2.jpg"],
|
|
||||||
creator="creator1",
|
|
||||||
creator_avatar="avatar.jpg",
|
|
||||||
sub_heading="Test subheading",
|
|
||||||
description="Test description",
|
|
||||||
categories=["cat1", "cat2"],
|
|
||||||
runs=50,
|
|
||||||
rating=4.5,
|
|
||||||
versions=["1.0", "2.0"],
|
|
||||||
agentGraphVersions=["1", "2"],
|
|
||||||
agentGraphId="test-graph-id",
|
|
||||||
last_updated=datetime.datetime.now(),
|
|
||||||
)
|
|
||||||
assert details.slug == "test-agent"
|
|
||||||
assert len(details.agent_image) == 2
|
|
||||||
assert len(details.categories) == 2
|
|
||||||
assert len(details.versions) == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_creator():
|
|
||||||
creator = store_model.Creator(
|
|
||||||
agent_rating=4.8,
|
|
||||||
agent_runs=1000,
|
|
||||||
name="Test Creator",
|
|
||||||
username="creator1",
|
|
||||||
description="Test description",
|
|
||||||
avatar_url="avatar.jpg",
|
|
||||||
num_agents=5,
|
|
||||||
is_featured=False,
|
|
||||||
)
|
|
||||||
assert creator.name == "Test Creator"
|
|
||||||
assert creator.num_agents == 5
|
|
||||||
|
|
||||||
|
|
||||||
def test_creators_response():
|
|
||||||
response = store_model.CreatorsResponse(
|
|
||||||
creators=[
|
|
||||||
store_model.Creator(
|
|
||||||
agent_rating=4.8,
|
|
||||||
agent_runs=1000,
|
|
||||||
name="Test Creator",
|
|
||||||
username="creator1",
|
|
||||||
description="Test description",
|
|
||||||
avatar_url="avatar.jpg",
|
|
||||||
num_agents=5,
|
|
||||||
is_featured=False,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
pagination=store_model.Pagination(
|
|
||||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
|
||||||
),
|
|
||||||
)
|
|
||||||
assert len(response.creators) == 1
|
|
||||||
assert response.pagination.total_items == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_creator_details():
|
|
||||||
details = store_model.CreatorDetails(
|
|
||||||
name="Test Creator",
|
|
||||||
username="creator1",
|
|
||||||
description="Test description",
|
|
||||||
links=["link1.com", "link2.com"],
|
|
||||||
avatar_url="avatar.jpg",
|
|
||||||
agent_rating=4.8,
|
|
||||||
agent_runs=1000,
|
|
||||||
top_categories=["cat1", "cat2"],
|
|
||||||
)
|
|
||||||
assert details.name == "Test Creator"
|
|
||||||
assert len(details.links) == 2
|
|
||||||
assert details.agent_rating == 4.8
|
|
||||||
assert len(details.top_categories) == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_store_submission():
|
|
||||||
submission = store_model.StoreSubmission(
|
|
||||||
listing_id="listing123",
|
|
||||||
agent_id="agent123",
|
|
||||||
agent_version=1,
|
|
||||||
sub_heading="Test subheading",
|
|
||||||
name="Test Agent",
|
|
||||||
slug="test-agent",
|
|
||||||
description="Test description",
|
|
||||||
image_urls=["image1.jpg", "image2.jpg"],
|
|
||||||
date_submitted=datetime.datetime(2023, 1, 1),
|
|
||||||
status=prisma.enums.SubmissionStatus.PENDING,
|
|
||||||
runs=50,
|
|
||||||
rating=4.5,
|
|
||||||
)
|
|
||||||
assert submission.name == "Test Agent"
|
|
||||||
assert len(submission.image_urls) == 2
|
|
||||||
assert submission.status == prisma.enums.SubmissionStatus.PENDING
|
|
||||||
|
|
||||||
|
|
||||||
def test_store_submissions_response():
|
|
||||||
response = store_model.StoreSubmissionsResponse(
|
|
||||||
submissions=[
|
|
||||||
store_model.StoreSubmission(
|
|
||||||
listing_id="listing123",
|
|
||||||
agent_id="agent123",
|
|
||||||
agent_version=1,
|
|
||||||
sub_heading="Test subheading",
|
|
||||||
name="Test Agent",
|
|
||||||
slug="test-agent",
|
|
||||||
description="Test description",
|
|
||||||
image_urls=["image1.jpg"],
|
|
||||||
date_submitted=datetime.datetime(2023, 1, 1),
|
|
||||||
status=prisma.enums.SubmissionStatus.PENDING,
|
|
||||||
runs=50,
|
|
||||||
rating=4.5,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
pagination=store_model.Pagination(
|
|
||||||
total_items=1, total_pages=1, current_page=1, page_size=20
|
|
||||||
),
|
|
||||||
)
|
|
||||||
assert len(response.submissions) == 1
|
|
||||||
assert response.pagination.total_items == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_store_submission_request():
|
|
||||||
request = store_model.StoreSubmissionRequest(
|
|
||||||
agent_id="agent123",
|
|
||||||
agent_version=1,
|
|
||||||
slug="test-agent",
|
|
||||||
name="Test Agent",
|
|
||||||
sub_heading="Test subheading",
|
|
||||||
video_url="video.mp4",
|
|
||||||
image_urls=["image1.jpg", "image2.jpg"],
|
|
||||||
description="Test description",
|
|
||||||
categories=["cat1", "cat2"],
|
|
||||||
)
|
|
||||||
assert request.agent_id == "agent123"
|
|
||||||
assert request.agent_version == 1
|
|
||||||
assert len(request.image_urls) == 2
|
|
||||||
assert len(request.categories) == 2
|
|
||||||
@@ -1,16 +1,17 @@
|
|||||||
import logging
|
import logging
|
||||||
import tempfile
|
import tempfile
|
||||||
import typing
|
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import autogpt_libs.auth
|
import autogpt_libs.auth
|
||||||
import fastapi
|
import fastapi
|
||||||
import fastapi.responses
|
import fastapi.responses
|
||||||
import prisma.enums
|
import prisma.enums
|
||||||
|
from fastapi import Query, Security
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
import backend.data.graph
|
import backend.data.graph
|
||||||
import backend.util.json
|
import backend.util.json
|
||||||
|
from backend.util.exceptions import NotFoundError
|
||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
|
|
||||||
from . import cache as store_cache
|
from . import cache as store_cache
|
||||||
@@ -34,22 +35,15 @@ router = fastapi.APIRouter()
|
|||||||
"/profile",
|
"/profile",
|
||||||
summary="Get user profile",
|
summary="Get user profile",
|
||||||
tags=["store", "private"],
|
tags=["store", "private"],
|
||||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||||
response_model=store_model.ProfileDetails,
|
|
||||||
)
|
)
|
||||||
async def get_profile(
|
async def get_profile(
|
||||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||||
):
|
) -> store_model.ProfileDetails:
|
||||||
"""
|
"""Get the profile details for the authenticated user."""
|
||||||
Get the profile details for the authenticated user.
|
|
||||||
Cached for 1 hour per user.
|
|
||||||
"""
|
|
||||||
profile = await store_db.get_user_profile(user_id)
|
profile = await store_db.get_user_profile(user_id)
|
||||||
if profile is None:
|
if profile is None:
|
||||||
return fastapi.responses.JSONResponse(
|
raise NotFoundError("User does not have a profile yet")
|
||||||
status_code=404,
|
|
||||||
content={"detail": "Profile not found"},
|
|
||||||
)
|
|
||||||
return profile
|
return profile
|
||||||
|
|
||||||
|
|
||||||
@@ -57,98 +51,17 @@ async def get_profile(
|
|||||||
"/profile",
|
"/profile",
|
||||||
summary="Update user profile",
|
summary="Update user profile",
|
||||||
tags=["store", "private"],
|
tags=["store", "private"],
|
||||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||||
response_model=store_model.CreatorDetails,
|
|
||||||
)
|
)
|
||||||
async def update_or_create_profile(
|
async def update_or_create_profile(
|
||||||
profile: store_model.Profile,
|
profile: store_model.Profile,
|
||||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||||
):
|
) -> store_model.ProfileDetails:
|
||||||
"""
|
"""Update the store profile for the authenticated user."""
|
||||||
Update the store profile for the authenticated user.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
profile (Profile): The updated profile details
|
|
||||||
user_id (str): ID of the authenticated user
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
CreatorDetails: The updated profile
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If there is an error updating the profile
|
|
||||||
"""
|
|
||||||
updated_profile = await store_db.update_profile(user_id=user_id, profile=profile)
|
updated_profile = await store_db.update_profile(user_id=user_id, profile=profile)
|
||||||
return updated_profile
|
return updated_profile
|
||||||
|
|
||||||
|
|
||||||
##############################################
|
|
||||||
############### Agent Endpoints ##############
|
|
||||||
##############################################
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/agents",
|
|
||||||
summary="List store agents",
|
|
||||||
tags=["store", "public"],
|
|
||||||
response_model=store_model.StoreAgentsResponse,
|
|
||||||
)
|
|
||||||
async def get_agents(
|
|
||||||
featured: bool = False,
|
|
||||||
creator: str | None = None,
|
|
||||||
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
|
|
||||||
search_query: str | None = None,
|
|
||||||
category: str | None = None,
|
|
||||||
page: int = 1,
|
|
||||||
page_size: int = 20,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get a paginated list of agents from the store with optional filtering and sorting.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
featured (bool, optional): Filter to only show featured agents. Defaults to False.
|
|
||||||
creator (str | None, optional): Filter agents by creator username. Defaults to None.
|
|
||||||
sorted_by (str | None, optional): Sort agents by "runs" or "rating". Defaults to None.
|
|
||||||
search_query (str | None, optional): Search agents by name, subheading and description. Defaults to None.
|
|
||||||
category (str | None, optional): Filter agents by category. Defaults to None.
|
|
||||||
page (int, optional): Page number for pagination. Defaults to 1.
|
|
||||||
page_size (int, optional): Number of agents per page. Defaults to 20.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
StoreAgentsResponse: Paginated list of agents matching the filters
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If page or page_size are less than 1
|
|
||||||
|
|
||||||
Used for:
|
|
||||||
- Home Page Featured Agents
|
|
||||||
- Home Page Top Agents
|
|
||||||
- Search Results
|
|
||||||
- Agent Details - Other Agents By Creator
|
|
||||||
- Agent Details - Similar Agents
|
|
||||||
- Creator Details - Agents By Creator
|
|
||||||
"""
|
|
||||||
if page < 1:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=422, detail="Page must be greater than 0"
|
|
||||||
)
|
|
||||||
|
|
||||||
if page_size < 1:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=422, detail="Page size must be greater than 0"
|
|
||||||
)
|
|
||||||
|
|
||||||
agents = await store_cache._get_cached_store_agents(
|
|
||||||
featured=featured,
|
|
||||||
creator=creator,
|
|
||||||
sorted_by=sorted_by,
|
|
||||||
search_query=search_query,
|
|
||||||
category=category,
|
|
||||||
page=page,
|
|
||||||
page_size=page_size,
|
|
||||||
)
|
|
||||||
return agents
|
|
||||||
|
|
||||||
|
|
||||||
##############################################
|
##############################################
|
||||||
############### Search Endpoints #############
|
############### Search Endpoints #############
|
||||||
##############################################
|
##############################################
|
||||||
@@ -158,60 +71,30 @@ async def get_agents(
|
|||||||
"/search",
|
"/search",
|
||||||
summary="Unified search across all content types",
|
summary="Unified search across all content types",
|
||||||
tags=["store", "public"],
|
tags=["store", "public"],
|
||||||
response_model=store_model.UnifiedSearchResponse,
|
|
||||||
)
|
)
|
||||||
async def unified_search(
|
async def unified_search(
|
||||||
query: str,
|
query: str,
|
||||||
content_types: list[str] | None = fastapi.Query(
|
content_types: list[prisma.enums.ContentType] | None = Query(
|
||||||
default=None,
|
default=None,
|
||||||
description="Content types to search: STORE_AGENT, BLOCK, DOCUMENTATION. If not specified, searches all.",
|
description="Content types to search. If not specified, searches all.",
|
||||||
),
|
),
|
||||||
page: int = 1,
|
page: int = Query(ge=1, default=1),
|
||||||
page_size: int = 20,
|
page_size: int = Query(ge=1, default=20),
|
||||||
user_id: str | None = fastapi.Security(
|
user_id: str | None = Security(
|
||||||
autogpt_libs.auth.get_optional_user_id, use_cache=False
|
autogpt_libs.auth.get_optional_user_id, use_cache=False
|
||||||
),
|
),
|
||||||
):
|
) -> store_model.UnifiedSearchResponse:
|
||||||
"""
|
"""
|
||||||
Search across all content types (store agents, blocks, documentation) using hybrid search.
|
Search across all content types (marketplace agents, blocks, documentation)
|
||||||
|
using hybrid search.
|
||||||
|
|
||||||
Combines semantic (embedding-based) and lexical (text-based) search for best results.
|
Combines semantic (embedding-based) and lexical (text-based) search for best results.
|
||||||
|
|
||||||
Args:
|
|
||||||
query: The search query string
|
|
||||||
content_types: Optional list of content types to filter by (STORE_AGENT, BLOCK, DOCUMENTATION)
|
|
||||||
page: Page number for pagination (default 1)
|
|
||||||
page_size: Number of results per page (default 20)
|
|
||||||
user_id: Optional authenticated user ID (for user-scoped content in future)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
UnifiedSearchResponse: Paginated list of search results with relevance scores
|
|
||||||
"""
|
"""
|
||||||
if page < 1:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=422, detail="Page must be greater than 0"
|
|
||||||
)
|
|
||||||
|
|
||||||
if page_size < 1:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=422, detail="Page size must be greater than 0"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert string content types to enum
|
|
||||||
content_type_enums: list[prisma.enums.ContentType] | None = None
|
|
||||||
if content_types:
|
|
||||||
try:
|
|
||||||
content_type_enums = [prisma.enums.ContentType(ct) for ct in content_types]
|
|
||||||
except ValueError as e:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=422,
|
|
||||||
detail=f"Invalid content type. Valid values: STORE_AGENT, BLOCK, DOCUMENTATION. Error: {e}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Perform unified hybrid search
|
# Perform unified hybrid search
|
||||||
results, total = await store_hybrid_search.unified_hybrid_search(
|
results, total = await store_hybrid_search.unified_hybrid_search(
|
||||||
query=query,
|
query=query,
|
||||||
content_types=content_type_enums,
|
content_types=content_types,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
page=page,
|
page=page,
|
||||||
page_size=page_size,
|
page_size=page_size,
|
||||||
@@ -245,22 +128,69 @@ async def unified_search(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
##############################################
|
||||||
|
############### Agent Endpoints ##############
|
||||||
|
##############################################
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/agents",
|
||||||
|
summary="List store agents",
|
||||||
|
tags=["store", "public"],
|
||||||
|
)
|
||||||
|
async def get_agents(
|
||||||
|
featured: bool = Query(
|
||||||
|
default=False, description="Filter to only show featured agents"
|
||||||
|
),
|
||||||
|
creator: str | None = Query(
|
||||||
|
default=None, description="Filter agents by creator username"
|
||||||
|
),
|
||||||
|
category: str | None = Query(default=None, description="Filter agents by category"),
|
||||||
|
search_query: str | None = Query(
|
||||||
|
default=None, description="Literal + semantic search on names and descriptions"
|
||||||
|
),
|
||||||
|
sorted_by: store_db.StoreAgentsSortOptions | None = Query(
|
||||||
|
default=None,
|
||||||
|
description="Property to sort results by. Ignored if search_query is provided.",
|
||||||
|
),
|
||||||
|
page: int = Query(ge=1, default=1),
|
||||||
|
page_size: int = Query(ge=1, default=20),
|
||||||
|
) -> store_model.StoreAgentsResponse:
|
||||||
|
"""
|
||||||
|
Get a paginated list of agents from the marketplace,
|
||||||
|
with optional filtering and sorting.
|
||||||
|
|
||||||
|
Used for:
|
||||||
|
- Home Page Featured Agents
|
||||||
|
- Home Page Top Agents
|
||||||
|
- Search Results
|
||||||
|
- Agent Details - Other Agents By Creator
|
||||||
|
- Agent Details - Similar Agents
|
||||||
|
- Creator Details - Agents By Creator
|
||||||
|
"""
|
||||||
|
agents = await store_cache._get_cached_store_agents(
|
||||||
|
featured=featured,
|
||||||
|
creator=creator,
|
||||||
|
sorted_by=sorted_by,
|
||||||
|
search_query=search_query,
|
||||||
|
category=category,
|
||||||
|
page=page,
|
||||||
|
page_size=page_size,
|
||||||
|
)
|
||||||
|
return agents
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/agents/{username}/{agent_name}",
|
"/agents/{username}/{agent_name}",
|
||||||
summary="Get specific agent",
|
summary="Get specific agent",
|
||||||
tags=["store", "public"],
|
tags=["store", "public"],
|
||||||
response_model=store_model.StoreAgentDetails,
|
|
||||||
)
|
)
|
||||||
async def get_agent(
|
async def get_agent_by_name(
|
||||||
username: str,
|
username: str,
|
||||||
agent_name: str,
|
agent_name: str,
|
||||||
include_changelog: bool = fastapi.Query(default=False),
|
include_changelog: bool = Query(default=False),
|
||||||
):
|
) -> store_model.StoreAgentDetails:
|
||||||
"""
|
"""Get details of a marketplace agent"""
|
||||||
This is only used on the AgentDetails Page.
|
|
||||||
|
|
||||||
It returns the store listing agents details.
|
|
||||||
"""
|
|
||||||
username = urllib.parse.unquote(username).lower()
|
username = urllib.parse.unquote(username).lower()
|
||||||
# URL decode the agent name since it comes from the URL path
|
# URL decode the agent name since it comes from the URL path
|
||||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||||
@@ -270,76 +200,82 @@ async def get_agent(
|
|||||||
return agent
|
return agent
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/graph/{store_listing_version_id}",
|
|
||||||
summary="Get agent graph",
|
|
||||||
tags=["store"],
|
|
||||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
|
||||||
)
|
|
||||||
async def get_graph_meta_by_store_listing_version_id(
|
|
||||||
store_listing_version_id: str,
|
|
||||||
) -> backend.data.graph.GraphModelWithoutNodes:
|
|
||||||
"""
|
|
||||||
Get Agent Graph from Store Listing Version ID.
|
|
||||||
"""
|
|
||||||
graph = await store_db.get_available_graph(store_listing_version_id)
|
|
||||||
return graph
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/agents/{store_listing_version_id}",
|
|
||||||
summary="Get agent by version",
|
|
||||||
tags=["store"],
|
|
||||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
|
||||||
response_model=store_model.StoreAgentDetails,
|
|
||||||
)
|
|
||||||
async def get_store_agent(store_listing_version_id: str):
|
|
||||||
"""
|
|
||||||
Get Store Agent Details from Store Listing Version ID.
|
|
||||||
"""
|
|
||||||
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/agents/{username}/{agent_name}/review",
|
"/agents/{username}/{agent_name}/review",
|
||||||
summary="Create agent review",
|
summary="Create agent review",
|
||||||
tags=["store"],
|
tags=["store"],
|
||||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||||
response_model=store_model.StoreReview,
|
|
||||||
)
|
)
|
||||||
async def create_review(
|
async def post_user_review_for_agent(
|
||||||
username: str,
|
username: str,
|
||||||
agent_name: str,
|
agent_name: str,
|
||||||
review: store_model.StoreReviewCreate,
|
review: store_model.StoreReviewCreate,
|
||||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||||
):
|
) -> store_model.StoreReview:
|
||||||
"""
|
"""Post a user review on a marketplace agent listing"""
|
||||||
Create a review for a store agent.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
username: Creator's username
|
|
||||||
agent_name: Name/slug of the agent
|
|
||||||
review: Review details including score and optional comments
|
|
||||||
user_id: ID of authenticated user creating the review
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The created review
|
|
||||||
"""
|
|
||||||
username = urllib.parse.unquote(username).lower()
|
username = urllib.parse.unquote(username).lower()
|
||||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||||
# Create the review
|
|
||||||
created_review = await store_db.create_store_review(
|
created_review = await store_db.create_store_review(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
store_listing_version_id=review.store_listing_version_id,
|
store_listing_version_id=review.store_listing_version_id,
|
||||||
score=review.score,
|
score=review.score,
|
||||||
comments=review.comments,
|
comments=review.comments,
|
||||||
)
|
)
|
||||||
|
|
||||||
return created_review
|
return created_review
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/listings/versions/{store_listing_version_id}",
|
||||||
|
summary="Get agent by version",
|
||||||
|
tags=["store"],
|
||||||
|
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||||
|
)
|
||||||
|
async def get_agent_by_listing_version(
|
||||||
|
store_listing_version_id: str,
|
||||||
|
) -> store_model.StoreAgentDetails:
|
||||||
|
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/listings/versions/{store_listing_version_id}/graph",
|
||||||
|
summary="Get agent graph",
|
||||||
|
tags=["store"],
|
||||||
|
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||||
|
)
|
||||||
|
async def get_graph_meta_by_store_listing_version_id(
|
||||||
|
store_listing_version_id: str,
|
||||||
|
) -> backend.data.graph.GraphModelWithoutNodes:
|
||||||
|
"""Get outline of graph belonging to a specific marketplace listing version"""
|
||||||
|
graph = await store_db.get_available_graph(store_listing_version_id)
|
||||||
|
return graph
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/listings/versions/{store_listing_version_id}/graph/download",
|
||||||
|
summary="Download agent file",
|
||||||
|
tags=["store", "public"],
|
||||||
|
)
|
||||||
|
async def download_agent_file(
|
||||||
|
store_listing_version_id: str,
|
||||||
|
) -> fastapi.responses.FileResponse:
|
||||||
|
"""Download agent graph file for a specific marketplace listing version"""
|
||||||
|
graph_data = await store_db.get_agent(store_listing_version_id)
|
||||||
|
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||||
|
|
||||||
|
# Sending graph as a stream (similar to marketplace v1)
|
||||||
|
with tempfile.NamedTemporaryFile(
|
||||||
|
mode="w", suffix=".json", delete=False
|
||||||
|
) as tmp_file:
|
||||||
|
tmp_file.write(backend.util.json.dumps(graph_data))
|
||||||
|
tmp_file.flush()
|
||||||
|
|
||||||
|
return fastapi.responses.FileResponse(
|
||||||
|
tmp_file.name, filename=file_name, media_type="application/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
##############################################
|
##############################################
|
||||||
############# Creator Endpoints #############
|
############# Creator Endpoints #############
|
||||||
##############################################
|
##############################################
|
||||||
@@ -349,37 +285,19 @@ async def create_review(
|
|||||||
"/creators",
|
"/creators",
|
||||||
summary="List store creators",
|
summary="List store creators",
|
||||||
tags=["store", "public"],
|
tags=["store", "public"],
|
||||||
response_model=store_model.CreatorsResponse,
|
|
||||||
)
|
)
|
||||||
async def get_creators(
|
async def get_creators(
|
||||||
featured: bool = False,
|
featured: bool = Query(
|
||||||
search_query: str | None = None,
|
default=False, description="Filter to only show featured creators"
|
||||||
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
|
),
|
||||||
page: int = 1,
|
search_query: str | None = Query(
|
||||||
page_size: int = 20,
|
default=None, description="Literal + semantic search on names and descriptions"
|
||||||
):
|
),
|
||||||
"""
|
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
|
||||||
This is needed for:
|
page: int = Query(ge=1, default=1),
|
||||||
- Home Page Featured Creators
|
page_size: int = Query(ge=1, default=20),
|
||||||
- Search Results Page
|
) -> store_model.CreatorsResponse:
|
||||||
|
"""List or search marketplace creators"""
|
||||||
---
|
|
||||||
|
|
||||||
To support this functionality we need:
|
|
||||||
- featured: bool - to limit the list to just featured agents
|
|
||||||
- search_query: str - vector search based on the creators profile description.
|
|
||||||
- sorted_by: [agent_rating, agent_runs] -
|
|
||||||
"""
|
|
||||||
if page < 1:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=422, detail="Page must be greater than 0"
|
|
||||||
)
|
|
||||||
|
|
||||||
if page_size < 1:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=422, detail="Page size must be greater than 0"
|
|
||||||
)
|
|
||||||
|
|
||||||
creators = await store_cache._get_cached_store_creators(
|
creators = await store_cache._get_cached_store_creators(
|
||||||
featured=featured,
|
featured=featured,
|
||||||
search_query=search_query,
|
search_query=search_query,
|
||||||
@@ -391,18 +309,12 @@ async def get_creators(
|
|||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/creator/{username}",
|
"/creators/{username}",
|
||||||
summary="Get creator details",
|
summary="Get creator details",
|
||||||
tags=["store", "public"],
|
tags=["store", "public"],
|
||||||
response_model=store_model.CreatorDetails,
|
|
||||||
)
|
)
|
||||||
async def get_creator(
|
async def get_creator(username: str) -> store_model.CreatorDetails:
|
||||||
username: str,
|
"""Get details on a marketplace creator"""
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get the details of a creator.
|
|
||||||
- Creator Details Page
|
|
||||||
"""
|
|
||||||
username = urllib.parse.unquote(username).lower()
|
username = urllib.parse.unquote(username).lower()
|
||||||
creator = await store_cache._get_cached_creator_details(username=username)
|
creator = await store_cache._get_cached_creator_details(username=username)
|
||||||
return creator
|
return creator
|
||||||
@@ -414,20 +326,17 @@ async def get_creator(
|
|||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/myagents",
|
"/my-unpublished-agents",
|
||||||
summary="Get my agents",
|
summary="Get my agents",
|
||||||
tags=["store", "private"],
|
tags=["store", "private"],
|
||||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||||
response_model=store_model.MyAgentsResponse,
|
|
||||||
)
|
)
|
||||||
async def get_my_agents(
|
async def get_my_unpublished_agents(
|
||||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||||
page: typing.Annotated[int, fastapi.Query(ge=1)] = 1,
|
page: int = Query(ge=1, default=1),
|
||||||
page_size: typing.Annotated[int, fastapi.Query(ge=1)] = 20,
|
page_size: int = Query(ge=1, default=20),
|
||||||
):
|
) -> store_model.MyUnpublishedAgentsResponse:
|
||||||
"""
|
"""List the authenticated user's unpublished agents"""
|
||||||
Get user's own agents.
|
|
||||||
"""
|
|
||||||
agents = await store_db.get_my_agents(user_id, page=page, page_size=page_size)
|
agents = await store_db.get_my_agents(user_id, page=page, page_size=page_size)
|
||||||
return agents
|
return agents
|
||||||
|
|
||||||
@@ -436,28 +345,17 @@ async def get_my_agents(
|
|||||||
"/submissions/{submission_id}",
|
"/submissions/{submission_id}",
|
||||||
summary="Delete store submission",
|
summary="Delete store submission",
|
||||||
tags=["store", "private"],
|
tags=["store", "private"],
|
||||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||||
response_model=bool,
|
|
||||||
)
|
)
|
||||||
async def delete_submission(
|
async def delete_submission(
|
||||||
submission_id: str,
|
submission_id: str,
|
||||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||||
):
|
) -> bool:
|
||||||
"""
|
"""Delete a marketplace listing submission"""
|
||||||
Delete a store listing submission.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id (str): ID of the authenticated user
|
|
||||||
submission_id (str): ID of the submission to be deleted
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if the submission was successfully deleted, False otherwise
|
|
||||||
"""
|
|
||||||
result = await store_db.delete_store_submission(
|
result = await store_db.delete_store_submission(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
submission_id=submission_id,
|
submission_id=submission_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -465,37 +363,14 @@ async def delete_submission(
|
|||||||
"/submissions",
|
"/submissions",
|
||||||
summary="List my submissions",
|
summary="List my submissions",
|
||||||
tags=["store", "private"],
|
tags=["store", "private"],
|
||||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||||
response_model=store_model.StoreSubmissionsResponse,
|
|
||||||
)
|
)
|
||||||
async def get_submissions(
|
async def get_submissions(
|
||||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||||
page: int = 1,
|
page: int = Query(ge=1, default=1),
|
||||||
page_size: int = 20,
|
page_size: int = Query(ge=1, default=20),
|
||||||
):
|
) -> store_model.StoreSubmissionsResponse:
|
||||||
"""
|
"""List the authenticated user's marketplace listing submissions"""
|
||||||
Get a paginated list of store submissions for the authenticated user.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id (str): ID of the authenticated user
|
|
||||||
page (int, optional): Page number for pagination. Defaults to 1.
|
|
||||||
page_size (int, optional): Number of submissions per page. Defaults to 20.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
StoreListingsResponse: Paginated list of store submissions
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If page or page_size are less than 1
|
|
||||||
"""
|
|
||||||
if page < 1:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=422, detail="Page must be greater than 0"
|
|
||||||
)
|
|
||||||
|
|
||||||
if page_size < 1:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=422, detail="Page size must be greater than 0"
|
|
||||||
)
|
|
||||||
listings = await store_db.get_store_submissions(
|
listings = await store_db.get_store_submissions(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
page=page,
|
page=page,
|
||||||
@@ -508,30 +383,17 @@ async def get_submissions(
|
|||||||
"/submissions",
|
"/submissions",
|
||||||
summary="Create store submission",
|
summary="Create store submission",
|
||||||
tags=["store", "private"],
|
tags=["store", "private"],
|
||||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||||
response_model=store_model.StoreSubmission,
|
|
||||||
)
|
)
|
||||||
async def create_submission(
|
async def create_submission(
|
||||||
submission_request: store_model.StoreSubmissionRequest,
|
submission_request: store_model.StoreSubmissionRequest,
|
||||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||||
):
|
) -> store_model.StoreSubmission:
|
||||||
"""
|
"""Submit a new marketplace listing for review"""
|
||||||
Create a new store listing submission.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
submission_request (StoreSubmissionRequest): The submission details
|
|
||||||
user_id (str): ID of the authenticated user submitting the listing
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
StoreSubmission: The created store submission
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If there is an error creating the submission
|
|
||||||
"""
|
|
||||||
result = await store_db.create_store_submission(
|
result = await store_db.create_store_submission(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
agent_id=submission_request.agent_id,
|
graph_id=submission_request.graph_id,
|
||||||
agent_version=submission_request.agent_version,
|
graph_version=submission_request.graph_version,
|
||||||
slug=submission_request.slug,
|
slug=submission_request.slug,
|
||||||
name=submission_request.name,
|
name=submission_request.name,
|
||||||
video_url=submission_request.video_url,
|
video_url=submission_request.video_url,
|
||||||
@@ -544,7 +406,6 @@ async def create_submission(
|
|||||||
changes_summary=submission_request.changes_summary or "Initial Submission",
|
changes_summary=submission_request.changes_summary or "Initial Submission",
|
||||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -552,28 +413,14 @@ async def create_submission(
|
|||||||
"/submissions/{store_listing_version_id}",
|
"/submissions/{store_listing_version_id}",
|
||||||
summary="Edit store submission",
|
summary="Edit store submission",
|
||||||
tags=["store", "private"],
|
tags=["store", "private"],
|
||||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||||
response_model=store_model.StoreSubmission,
|
|
||||||
)
|
)
|
||||||
async def edit_submission(
|
async def edit_submission(
|
||||||
store_listing_version_id: str,
|
store_listing_version_id: str,
|
||||||
submission_request: store_model.StoreSubmissionEditRequest,
|
submission_request: store_model.StoreSubmissionEditRequest,
|
||||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||||
):
|
) -> store_model.StoreSubmission:
|
||||||
"""
|
"""Update a pending marketplace listing submission"""
|
||||||
Edit an existing store listing submission.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
store_listing_version_id (str): ID of the store listing version to edit
|
|
||||||
submission_request (StoreSubmissionRequest): The updated submission details
|
|
||||||
user_id (str): ID of the authenticated user editing the listing
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
StoreSubmission: The updated store submission
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If there is an error editing the submission
|
|
||||||
"""
|
|
||||||
result = await store_db.edit_store_submission(
|
result = await store_db.edit_store_submission(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
store_listing_version_id=store_listing_version_id,
|
store_listing_version_id=store_listing_version_id,
|
||||||
@@ -588,7 +435,6 @@ async def edit_submission(
|
|||||||
changes_summary=submission_request.changes_summary,
|
changes_summary=submission_request.changes_summary,
|
||||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -596,115 +442,61 @@ async def edit_submission(
|
|||||||
"/submissions/media",
|
"/submissions/media",
|
||||||
summary="Upload submission media",
|
summary="Upload submission media",
|
||||||
tags=["store", "private"],
|
tags=["store", "private"],
|
||||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||||
)
|
)
|
||||||
async def upload_submission_media(
|
async def upload_submission_media(
|
||||||
file: fastapi.UploadFile,
|
file: fastapi.UploadFile,
|
||||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||||
):
|
) -> str:
|
||||||
"""
|
"""Upload media for a marketplace listing submission"""
|
||||||
Upload media (images/videos) for a store listing submission.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file (UploadFile): The media file to upload
|
|
||||||
user_id (str): ID of the authenticated user uploading the media
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: URL of the uploaded media file
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If there is an error uploading the media
|
|
||||||
"""
|
|
||||||
media_url = await store_media.upload_media(user_id=user_id, file=file)
|
media_url = await store_media.upload_media(user_id=user_id, file=file)
|
||||||
return media_url
|
return media_url
|
||||||
|
|
||||||
|
|
||||||
|
class ImageURLResponse(BaseModel):
|
||||||
|
image_url: str
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/submissions/generate_image",
|
"/submissions/generate_image",
|
||||||
summary="Generate submission image",
|
summary="Generate submission image",
|
||||||
tags=["store", "private"],
|
tags=["store", "private"],
|
||||||
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
||||||
)
|
)
|
||||||
async def generate_image(
|
async def generate_image(
|
||||||
agent_id: str,
|
graph_id: str,
|
||||||
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
||||||
) -> fastapi.responses.Response:
|
) -> ImageURLResponse:
|
||||||
"""
|
"""
|
||||||
Generate an image for a store listing submission.
|
Generate an image for a marketplace listing submission based on the properties
|
||||||
|
of a given graph.
|
||||||
Args:
|
|
||||||
agent_id (str): ID of the agent to generate an image for
|
|
||||||
user_id (str): ID of the authenticated user
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
JSONResponse: JSON containing the URL of the generated image
|
|
||||||
"""
|
"""
|
||||||
agent = await backend.data.graph.get_graph(
|
graph = await backend.data.graph.get_graph(
|
||||||
graph_id=agent_id, version=None, user_id=user_id
|
graph_id=graph_id, version=None, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if not agent:
|
if not graph:
|
||||||
raise fastapi.HTTPException(
|
raise NotFoundError(f"Agent graph #{graph_id} not found")
|
||||||
status_code=404, detail=f"Agent with ID {agent_id} not found"
|
|
||||||
)
|
|
||||||
# Use .jpeg here since we are generating JPEG images
|
# Use .jpeg here since we are generating JPEG images
|
||||||
filename = f"agent_{agent_id}.jpeg"
|
filename = f"agent_{graph_id}.jpeg"
|
||||||
|
|
||||||
existing_url = await store_media.check_media_exists(user_id, filename)
|
existing_url = await store_media.check_media_exists(user_id, filename)
|
||||||
if existing_url:
|
if existing_url:
|
||||||
logger.info(f"Using existing image for agent {agent_id}")
|
logger.info(f"Using existing image for agent graph {graph_id}")
|
||||||
return fastapi.responses.JSONResponse(content={"image_url": existing_url})
|
return ImageURLResponse(image_url=existing_url)
|
||||||
# Generate agent image as JPEG
|
# Generate agent image as JPEG
|
||||||
image = await store_image_gen.generate_agent_image(agent=agent)
|
image = await store_image_gen.generate_agent_image(agent=graph)
|
||||||
|
|
||||||
# Create UploadFile with the correct filename and content_type
|
# Create UploadFile with the correct filename and content_type
|
||||||
image_file = fastapi.UploadFile(
|
image_file = fastapi.UploadFile(
|
||||||
file=image,
|
file=image,
|
||||||
filename=filename,
|
filename=filename,
|
||||||
)
|
)
|
||||||
|
|
||||||
image_url = await store_media.upload_media(
|
image_url = await store_media.upload_media(
|
||||||
user_id=user_id, file=image_file, use_file_name=True
|
user_id=user_id, file=image_file, use_file_name=True
|
||||||
)
|
)
|
||||||
|
|
||||||
return fastapi.responses.JSONResponse(content={"image_url": image_url})
|
return ImageURLResponse(image_url=image_url)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/download/agents/{store_listing_version_id}",
|
|
||||||
summary="Download agent file",
|
|
||||||
tags=["store", "public"],
|
|
||||||
)
|
|
||||||
async def download_agent_file(
|
|
||||||
store_listing_version_id: str = fastapi.Path(
|
|
||||||
..., description="The ID of the agent to download"
|
|
||||||
),
|
|
||||||
) -> fastapi.responses.FileResponse:
|
|
||||||
"""
|
|
||||||
Download the agent file by streaming its content.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
store_listing_version_id (str): The ID of the agent to download
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
StreamingResponse: A streaming response containing the agent's graph data.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If the agent is not found or an unexpected error occurs.
|
|
||||||
"""
|
|
||||||
graph_data = await store_db.get_agent(store_listing_version_id)
|
|
||||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
|
||||||
|
|
||||||
# Sending graph as a stream (similar to marketplace v1)
|
|
||||||
with tempfile.NamedTemporaryFile(
|
|
||||||
mode="w", suffix=".json", delete=False
|
|
||||||
) as tmp_file:
|
|
||||||
tmp_file.write(backend.util.json.dumps(graph_data))
|
|
||||||
tmp_file.flush()
|
|
||||||
|
|
||||||
return fastapi.responses.FileResponse(
|
|
||||||
tmp_file.name, filename=file_name, media_type="application/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
##############################################
|
##############################################
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import pytest
|
|||||||
import pytest_mock
|
import pytest_mock
|
||||||
from pytest_snapshot.plugin import Snapshot
|
from pytest_snapshot.plugin import Snapshot
|
||||||
|
|
||||||
|
from backend.api.features.store.db import StoreAgentsSortOptions
|
||||||
|
|
||||||
from . import model as store_model
|
from . import model as store_model
|
||||||
from . import routes as store_routes
|
from . import routes as store_routes
|
||||||
|
|
||||||
@@ -196,7 +198,7 @@ def test_get_agents_sorted(
|
|||||||
mock_db_call.assert_called_once_with(
|
mock_db_call.assert_called_once_with(
|
||||||
featured=False,
|
featured=False,
|
||||||
creators=None,
|
creators=None,
|
||||||
sorted_by="runs",
|
sorted_by=StoreAgentsSortOptions.RUNS,
|
||||||
search_query=None,
|
search_query=None,
|
||||||
category=None,
|
category=None,
|
||||||
page=1,
|
page=1,
|
||||||
@@ -380,9 +382,11 @@ def test_get_agent_details(
|
|||||||
runs=100,
|
runs=100,
|
||||||
rating=4.5,
|
rating=4.5,
|
||||||
versions=["1.0.0", "1.1.0"],
|
versions=["1.0.0", "1.1.0"],
|
||||||
agentGraphVersions=["1", "2"],
|
graph_versions=["1", "2"],
|
||||||
agentGraphId="test-graph-id",
|
graph_id="test-graph-id",
|
||||||
last_updated=FIXED_NOW,
|
last_updated=FIXED_NOW,
|
||||||
|
active_version_id="test-version-id",
|
||||||
|
has_approved_version=True,
|
||||||
)
|
)
|
||||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agent_details")
|
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agent_details")
|
||||||
mock_db_call.return_value = mocked_value
|
mock_db_call.return_value = mocked_value
|
||||||
@@ -435,15 +439,17 @@ def test_get_creators_pagination(
|
|||||||
) -> None:
|
) -> None:
|
||||||
mocked_value = store_model.CreatorsResponse(
|
mocked_value = store_model.CreatorsResponse(
|
||||||
creators=[
|
creators=[
|
||||||
store_model.Creator(
|
store_model.CreatorDetails(
|
||||||
name=f"Creator {i}",
|
name=f"Creator {i}",
|
||||||
username=f"creator{i}",
|
username=f"creator{i}",
|
||||||
description=f"Creator {i} description",
|
|
||||||
avatar_url=f"avatar{i}.jpg",
|
avatar_url=f"avatar{i}.jpg",
|
||||||
num_agents=1,
|
description=f"Creator {i} description",
|
||||||
agent_rating=4.5,
|
links=[f"user{i}.link.com"],
|
||||||
agent_runs=100,
|
|
||||||
is_featured=False,
|
is_featured=False,
|
||||||
|
num_agents=1,
|
||||||
|
agent_runs=100,
|
||||||
|
agent_rating=4.5,
|
||||||
|
top_categories=["cat1", "cat2", "cat3"],
|
||||||
)
|
)
|
||||||
for i in range(5)
|
for i in range(5)
|
||||||
],
|
],
|
||||||
@@ -496,19 +502,19 @@ def test_get_creator_details(
|
|||||||
mocked_value = store_model.CreatorDetails(
|
mocked_value = store_model.CreatorDetails(
|
||||||
name="Test User",
|
name="Test User",
|
||||||
username="creator1",
|
username="creator1",
|
||||||
|
avatar_url="avatar.jpg",
|
||||||
description="Test creator description",
|
description="Test creator description",
|
||||||
links=["link1.com", "link2.com"],
|
links=["link1.com", "link2.com"],
|
||||||
avatar_url="avatar.jpg",
|
is_featured=True,
|
||||||
agent_rating=4.8,
|
num_agents=5,
|
||||||
agent_runs=1000,
|
agent_runs=1000,
|
||||||
|
agent_rating=4.8,
|
||||||
top_categories=["category1", "category2"],
|
top_categories=["category1", "category2"],
|
||||||
)
|
)
|
||||||
mock_db_call = mocker.patch(
|
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_creator")
|
||||||
"backend.api.features.store.db.get_store_creator_details"
|
|
||||||
)
|
|
||||||
mock_db_call.return_value = mocked_value
|
mock_db_call.return_value = mocked_value
|
||||||
|
|
||||||
response = client.get("/creator/creator1")
|
response = client.get("/creators/creator1")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
data = store_model.CreatorDetails.model_validate(response.json())
|
data = store_model.CreatorDetails.model_validate(response.json())
|
||||||
@@ -528,19 +534,26 @@ def test_get_submissions_success(
|
|||||||
submissions=[
|
submissions=[
|
||||||
store_model.StoreSubmission(
|
store_model.StoreSubmission(
|
||||||
listing_id="test-listing-id",
|
listing_id="test-listing-id",
|
||||||
name="Test Agent",
|
user_id="test-user-id",
|
||||||
description="Test agent description",
|
|
||||||
image_urls=["test.jpg"],
|
|
||||||
date_submitted=FIXED_NOW,
|
|
||||||
status=prisma.enums.SubmissionStatus.APPROVED,
|
|
||||||
runs=50,
|
|
||||||
rating=4.2,
|
|
||||||
agent_id="test-agent-id",
|
|
||||||
agent_version=1,
|
|
||||||
sub_heading="Test agent subheading",
|
|
||||||
slug="test-agent",
|
slug="test-agent",
|
||||||
video_url="test.mp4",
|
listing_version_id="test-version-id",
|
||||||
|
listing_version=1,
|
||||||
|
graph_id="test-agent-id",
|
||||||
|
graph_version=1,
|
||||||
|
name="Test Agent",
|
||||||
|
sub_heading="Test agent subheading",
|
||||||
|
description="Test agent description",
|
||||||
|
instructions="Click the button!",
|
||||||
categories=["test-category"],
|
categories=["test-category"],
|
||||||
|
image_urls=["test.jpg"],
|
||||||
|
video_url="test.mp4",
|
||||||
|
agent_output_demo_url="demo_video.mp4",
|
||||||
|
submitted_at=FIXED_NOW,
|
||||||
|
changes_summary="Initial Submission",
|
||||||
|
status=prisma.enums.SubmissionStatus.APPROVED,
|
||||||
|
run_count=50,
|
||||||
|
review_count=5,
|
||||||
|
review_avg_rating=4.2,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
pagination=store_model.Pagination(
|
pagination=store_model.Pagination(
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import pytest
|
|||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
|
|
||||||
from . import cache as store_cache
|
from . import cache as store_cache
|
||||||
|
from .db import StoreAgentsSortOptions
|
||||||
from .model import StoreAgent, StoreAgentsResponse
|
from .model import StoreAgent, StoreAgentsResponse
|
||||||
|
|
||||||
|
|
||||||
@@ -215,7 +216,7 @@ class TestCacheDeletion:
|
|||||||
await store_cache._get_cached_store_agents(
|
await store_cache._get_cached_store_agents(
|
||||||
featured=True,
|
featured=True,
|
||||||
creator="testuser",
|
creator="testuser",
|
||||||
sorted_by="rating",
|
sorted_by=StoreAgentsSortOptions.RATING,
|
||||||
search_query="AI assistant",
|
search_query="AI assistant",
|
||||||
category="productivity",
|
category="productivity",
|
||||||
page=2,
|
page=2,
|
||||||
@@ -227,7 +228,7 @@ class TestCacheDeletion:
|
|||||||
deleted = store_cache._get_cached_store_agents.cache_delete(
|
deleted = store_cache._get_cached_store_agents.cache_delete(
|
||||||
featured=True,
|
featured=True,
|
||||||
creator="testuser",
|
creator="testuser",
|
||||||
sorted_by="rating",
|
sorted_by=StoreAgentsSortOptions.RATING,
|
||||||
search_query="AI assistant",
|
search_query="AI assistant",
|
||||||
category="productivity",
|
category="productivity",
|
||||||
page=2,
|
page=2,
|
||||||
@@ -239,7 +240,7 @@ class TestCacheDeletion:
|
|||||||
deleted = store_cache._get_cached_store_agents.cache_delete(
|
deleted = store_cache._get_cached_store_agents.cache_delete(
|
||||||
featured=True,
|
featured=True,
|
||||||
creator="testuser",
|
creator="testuser",
|
||||||
sorted_by="rating",
|
sorted_by=StoreAgentsSortOptions.RATING,
|
||||||
search_query="AI assistant",
|
search_query="AI assistant",
|
||||||
category="productivity",
|
category="productivity",
|
||||||
page=2,
|
page=2,
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ from backend.data.credit import (
|
|||||||
set_auto_top_up,
|
set_auto_top_up,
|
||||||
)
|
)
|
||||||
from backend.data.graph import GraphSettings
|
from backend.data.graph import GraphSettings
|
||||||
|
from backend.data.invited_user import get_or_activate_user
|
||||||
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
from backend.data.model import CredentialsMetaInput, UserOnboarding
|
||||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||||
from backend.data.onboarding import (
|
from backend.data.onboarding import (
|
||||||
@@ -70,7 +71,6 @@ from backend.data.onboarding import (
|
|||||||
update_user_onboarding,
|
update_user_onboarding,
|
||||||
)
|
)
|
||||||
from backend.data.user import (
|
from backend.data.user import (
|
||||||
get_or_create_user,
|
|
||||||
get_user_by_id,
|
get_user_by_id,
|
||||||
get_user_notification_preference,
|
get_user_notification_preference,
|
||||||
update_user_email,
|
update_user_email,
|
||||||
@@ -126,6 +126,9 @@ v1_router = APIRouter()
|
|||||||
########################################################
|
########################################################
|
||||||
|
|
||||||
|
|
||||||
|
_tally_background_tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
|
|
||||||
@v1_router.post(
|
@v1_router.post(
|
||||||
"/auth/user",
|
"/auth/user",
|
||||||
summary="Get or create user",
|
summary="Get or create user",
|
||||||
@@ -133,7 +136,23 @@ v1_router = APIRouter()
|
|||||||
dependencies=[Security(requires_user)],
|
dependencies=[Security(requires_user)],
|
||||||
)
|
)
|
||||||
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
||||||
user = await get_or_create_user(user_data)
|
user = await get_or_activate_user(user_data)
|
||||||
|
|
||||||
|
# Fire-and-forget: backfill Tally understanding when invite pre-seeding did
|
||||||
|
# not produce a stored result before first activation.
|
||||||
|
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
|
||||||
|
if age_seconds < 30:
|
||||||
|
try:
|
||||||
|
from backend.data.tally import populate_understanding_from_tally
|
||||||
|
|
||||||
|
task = asyncio.create_task(
|
||||||
|
populate_understanding_from_tally(user.id, user.email)
|
||||||
|
)
|
||||||
|
_tally_background_tasks.add(task)
|
||||||
|
task.add_done_callback(_tally_background_tasks.discard)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Failed to start Tally population task", exc_info=True)
|
||||||
|
|
||||||
return user.model_dump()
|
return user.model_dump()
|
||||||
|
|
||||||
|
|
||||||
@@ -144,7 +163,8 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
|
|||||||
dependencies=[Security(requires_user)],
|
dependencies=[Security(requires_user)],
|
||||||
)
|
)
|
||||||
async def update_user_email_route(
|
async def update_user_email_route(
|
||||||
user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
|
email: str = Body(...),
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
await update_user_email(user_id, email)
|
await update_user_email(user_id, email)
|
||||||
|
|
||||||
@@ -158,10 +178,16 @@ async def update_user_email_route(
|
|||||||
dependencies=[Security(requires_user)],
|
dependencies=[Security(requires_user)],
|
||||||
)
|
)
|
||||||
async def get_user_timezone_route(
|
async def get_user_timezone_route(
|
||||||
user_data: dict = Security(get_jwt_payload),
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
) -> TimezoneResponse:
|
) -> TimezoneResponse:
|
||||||
"""Get user timezone setting."""
|
"""Get user timezone setting."""
|
||||||
user = await get_or_create_user(user_data)
|
try:
|
||||||
|
user = await get_user_by_id(user_id)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_404_NOT_FOUND,
|
||||||
|
detail="User not found. Please complete activation via /auth/user first.",
|
||||||
|
)
|
||||||
return TimezoneResponse(timezone=user.timezone)
|
return TimezoneResponse(timezone=user.timezone)
|
||||||
|
|
||||||
|
|
||||||
@@ -172,7 +198,8 @@ async def get_user_timezone_route(
|
|||||||
dependencies=[Security(requires_user)],
|
dependencies=[Security(requires_user)],
|
||||||
)
|
)
|
||||||
async def update_user_timezone_route(
|
async def update_user_timezone_route(
|
||||||
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
|
request: UpdateTimezoneRequest,
|
||||||
) -> TimezoneResponse:
|
) -> TimezoneResponse:
|
||||||
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
|
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
|
||||||
user = await update_user_timezone(user_id, str(request.timezone))
|
user = await update_user_timezone(user_id, str(request.timezone))
|
||||||
@@ -428,7 +455,6 @@ async def execute_graph_block(
|
|||||||
async def upload_file(
|
async def upload_file(
|
||||||
user_id: Annotated[str, Security(get_user_id)],
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
provider: str = "gcs",
|
|
||||||
expiration_hours: int = 24,
|
expiration_hours: int = 24,
|
||||||
) -> UploadFileResponse:
|
) -> UploadFileResponse:
|
||||||
"""
|
"""
|
||||||
@@ -491,7 +517,6 @@ async def upload_file(
|
|||||||
storage_path = await cloud_storage.store_file(
|
storage_path = await cloud_storage.store_file(
|
||||||
content=content,
|
content=content,
|
||||||
filename=file_name,
|
filename=file_name,
|
||||||
provider=provider,
|
|
||||||
expiration_hours=expiration_hours,
|
expiration_hours=expiration_hours,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
@@ -43,6 +43,7 @@ def test_get_or_create_user_route(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Test get or create user endpoint"""
|
"""Test get or create user endpoint"""
|
||||||
mock_user = Mock()
|
mock_user = Mock()
|
||||||
|
mock_user.created_at = datetime.now(timezone.utc)
|
||||||
mock_user.model_dump.return_value = {
|
mock_user.model_dump.return_value = {
|
||||||
"id": test_user_id,
|
"id": test_user_id,
|
||||||
"email": "test@example.com",
|
"email": "test@example.com",
|
||||||
@@ -50,7 +51,7 @@ def test_get_or_create_user_route(
|
|||||||
}
|
}
|
||||||
|
|
||||||
mocker.patch(
|
mocker.patch(
|
||||||
"backend.api.features.v1.get_or_create_user",
|
"backend.api.features.v1.get_or_activate_user",
|
||||||
return_value=mock_user,
|
return_value=mock_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -514,7 +515,6 @@ async def test_upload_file_success(test_user_id: str):
|
|||||||
result = await upload_file(
|
result = await upload_file(
|
||||||
file=upload_file_mock,
|
file=upload_file_mock,
|
||||||
user_id=test_user_id,
|
user_id=test_user_id,
|
||||||
provider="gcs",
|
|
||||||
expiration_hours=24,
|
expiration_hours=24,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -532,7 +532,6 @@ async def test_upload_file_success(test_user_id: str):
|
|||||||
mock_handler.store_file.assert_called_once_with(
|
mock_handler.store_file.assert_called_once_with(
|
||||||
content=file_content,
|
content=file_content,
|
||||||
filename="test.txt",
|
filename="test.txt",
|
||||||
provider="gcs",
|
|
||||||
expiration_hours=24,
|
expiration_hours=24,
|
||||||
user_id=test_user_id,
|
user_id=test_user_id,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,15 +3,29 @@ Workspace API routes for managing user file storage.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||||
|
from fastapi import Query, UploadFile
|
||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from backend.data.workspace import get_workspace, get_workspace_file
|
from backend.data.workspace import (
|
||||||
|
WorkspaceFile,
|
||||||
|
count_workspace_files,
|
||||||
|
get_or_create_workspace,
|
||||||
|
get_workspace,
|
||||||
|
get_workspace_file,
|
||||||
|
get_workspace_total_size,
|
||||||
|
soft_delete_workspace_file,
|
||||||
|
)
|
||||||
|
from backend.util.settings import Config
|
||||||
|
from backend.util.virus_scanner import scan_content_safe
|
||||||
|
from backend.util.workspace import WorkspaceManager
|
||||||
from backend.util.workspace_storage import get_workspace_storage
|
from backend.util.workspace_storage import get_workspace_storage
|
||||||
|
|
||||||
|
|
||||||
@@ -44,11 +58,11 @@ router = fastapi.APIRouter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _create_streaming_response(content: bytes, file) -> Response:
|
def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
|
||||||
"""Create a streaming response for file content."""
|
"""Create a streaming response for file content."""
|
||||||
return Response(
|
return Response(
|
||||||
content=content,
|
content=content,
|
||||||
media_type=file.mimeType,
|
media_type=file.mime_type,
|
||||||
headers={
|
headers={
|
||||||
"Content-Disposition": _sanitize_filename_for_header(file.name),
|
"Content-Disposition": _sanitize_filename_for_header(file.name),
|
||||||
"Content-Length": str(len(content)),
|
"Content-Length": str(len(content)),
|
||||||
@@ -56,7 +70,7 @@ def _create_streaming_response(content: bytes, file) -> Response:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _create_file_download_response(file) -> Response:
|
async def _create_file_download_response(file: WorkspaceFile) -> Response:
|
||||||
"""
|
"""
|
||||||
Create a download response for a workspace file.
|
Create a download response for a workspace file.
|
||||||
|
|
||||||
@@ -66,38 +80,57 @@ async def _create_file_download_response(file) -> Response:
|
|||||||
storage = await get_workspace_storage()
|
storage = await get_workspace_storage()
|
||||||
|
|
||||||
# For local storage, stream the file directly
|
# For local storage, stream the file directly
|
||||||
if file.storagePath.startswith("local://"):
|
if file.storage_path.startswith("local://"):
|
||||||
content = await storage.retrieve(file.storagePath)
|
content = await storage.retrieve(file.storage_path)
|
||||||
return _create_streaming_response(content, file)
|
return _create_streaming_response(content, file)
|
||||||
|
|
||||||
# For GCS, try to redirect to signed URL, fall back to streaming
|
# For GCS, try to redirect to signed URL, fall back to streaming
|
||||||
try:
|
try:
|
||||||
url = await storage.get_download_url(file.storagePath, expires_in=300)
|
url = await storage.get_download_url(file.storage_path, expires_in=300)
|
||||||
# If we got back an API path (fallback), stream directly instead
|
# If we got back an API path (fallback), stream directly instead
|
||||||
if url.startswith("/api/"):
|
if url.startswith("/api/"):
|
||||||
content = await storage.retrieve(file.storagePath)
|
content = await storage.retrieve(file.storage_path)
|
||||||
return _create_streaming_response(content, file)
|
return _create_streaming_response(content, file)
|
||||||
return fastapi.responses.RedirectResponse(url=url, status_code=302)
|
return fastapi.responses.RedirectResponse(url=url, status_code=302)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Log the signed URL failure with context
|
# Log the signed URL failure with context
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to get signed URL for file {file.id} "
|
f"Failed to get signed URL for file {file.id} "
|
||||||
f"(storagePath={file.storagePath}): {e}",
|
f"(storagePath={file.storage_path}): {e}",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
# Fall back to streaming directly from GCS
|
# Fall back to streaming directly from GCS
|
||||||
try:
|
try:
|
||||||
content = await storage.retrieve(file.storagePath)
|
content = await storage.retrieve(file.storage_path)
|
||||||
return _create_streaming_response(content, file)
|
return _create_streaming_response(content, file)
|
||||||
except Exception as fallback_error:
|
except Exception as fallback_error:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Fallback streaming also failed for file {file.id} "
|
f"Fallback streaming also failed for file {file.id} "
|
||||||
f"(storagePath={file.storagePath}): {fallback_error}",
|
f"(storagePath={file.storage_path}): {fallback_error}",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class UploadFileResponse(BaseModel):
|
||||||
|
file_id: str
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
mime_type: str
|
||||||
|
size_bytes: int
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteFileResponse(BaseModel):
|
||||||
|
deleted: bool
|
||||||
|
|
||||||
|
|
||||||
|
class StorageUsageResponse(BaseModel):
|
||||||
|
used_bytes: int
|
||||||
|
limit_bytes: int
|
||||||
|
used_percent: float
|
||||||
|
file_count: int
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/files/{file_id}/download",
|
"/files/{file_id}/download",
|
||||||
summary="Download file by ID",
|
summary="Download file by ID",
|
||||||
@@ -120,3 +153,148 @@ async def download_file(
|
|||||||
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
||||||
|
|
||||||
return await _create_file_download_response(file)
|
return await _create_file_download_response(file)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete(
|
||||||
|
"/files/{file_id}",
|
||||||
|
summary="Delete a workspace file",
|
||||||
|
)
|
||||||
|
async def delete_workspace_file(
|
||||||
|
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||||
|
file_id: str,
|
||||||
|
) -> DeleteFileResponse:
|
||||||
|
"""
|
||||||
|
Soft-delete a workspace file and attempt to remove it from storage.
|
||||||
|
|
||||||
|
Used when a user clears a file input in the builder.
|
||||||
|
"""
|
||||||
|
workspace = await get_workspace(user_id)
|
||||||
|
if workspace is None:
|
||||||
|
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
|
||||||
|
|
||||||
|
manager = WorkspaceManager(user_id, workspace.id)
|
||||||
|
deleted = await manager.delete_file(file_id)
|
||||||
|
if not deleted:
|
||||||
|
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
||||||
|
|
||||||
|
return DeleteFileResponse(deleted=True)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/files/upload",
|
||||||
|
summary="Upload file to workspace",
|
||||||
|
)
|
||||||
|
async def upload_file(
|
||||||
|
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||||
|
file: UploadFile,
|
||||||
|
session_id: str | None = Query(default=None),
|
||||||
|
) -> UploadFileResponse:
|
||||||
|
"""
|
||||||
|
Upload a file to the user's workspace.
|
||||||
|
|
||||||
|
Files are stored in session-scoped paths when session_id is provided,
|
||||||
|
so the agent's session-scoped tools can discover them automatically.
|
||||||
|
"""
|
||||||
|
config = Config()
|
||||||
|
|
||||||
|
# Sanitize filename — strip any directory components
|
||||||
|
filename = os.path.basename(file.filename or "upload") or "upload"
|
||||||
|
|
||||||
|
# Read file content with early abort on size limit
|
||||||
|
max_file_bytes = config.max_file_size_mb * 1024 * 1024
|
||||||
|
chunks: list[bytes] = []
|
||||||
|
total_size = 0
|
||||||
|
while chunk := await file.read(64 * 1024): # 64KB chunks
|
||||||
|
total_size += len(chunk)
|
||||||
|
if total_size > max_file_bytes:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=413,
|
||||||
|
detail=f"File exceeds maximum size of {config.max_file_size_mb} MB",
|
||||||
|
)
|
||||||
|
chunks.append(chunk)
|
||||||
|
content = b"".join(chunks)
|
||||||
|
|
||||||
|
# Get or create workspace
|
||||||
|
workspace = await get_or_create_workspace(user_id)
|
||||||
|
|
||||||
|
# Pre-write storage cap check (soft check — final enforcement is post-write)
|
||||||
|
storage_limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
|
||||||
|
current_usage = await get_workspace_total_size(workspace.id)
|
||||||
|
if storage_limit_bytes and current_usage + len(content) > storage_limit_bytes:
|
||||||
|
used_percent = (current_usage / storage_limit_bytes) * 100
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=413,
|
||||||
|
detail={
|
||||||
|
"message": "Storage limit exceeded",
|
||||||
|
"used_bytes": current_usage,
|
||||||
|
"limit_bytes": storage_limit_bytes,
|
||||||
|
"used_percent": round(used_percent, 1),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warn at 80% usage
|
||||||
|
if (
|
||||||
|
storage_limit_bytes
|
||||||
|
and (usage_ratio := (current_usage + len(content)) / storage_limit_bytes) >= 0.8
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
f"User {user_id} workspace storage at {usage_ratio * 100:.1f}% "
|
||||||
|
f"({current_usage + len(content)} / {storage_limit_bytes} bytes)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Virus scan
|
||||||
|
await scan_content_safe(content, filename=filename)
|
||||||
|
|
||||||
|
# Write file via WorkspaceManager
|
||||||
|
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||||
|
try:
|
||||||
|
workspace_file = await manager.write_file(content, filename)
|
||||||
|
except ValueError as e:
|
||||||
|
raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
|
||||||
|
|
||||||
|
# Post-write storage check — eliminates TOCTOU race on the quota.
|
||||||
|
# If a concurrent upload pushed us over the limit, undo this write.
|
||||||
|
new_total = await get_workspace_total_size(workspace.id)
|
||||||
|
if storage_limit_bytes and new_total > storage_limit_bytes:
|
||||||
|
await soft_delete_workspace_file(workspace_file.id, workspace.id)
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=413,
|
||||||
|
detail={
|
||||||
|
"message": "Storage limit exceeded (concurrent upload)",
|
||||||
|
"used_bytes": new_total,
|
||||||
|
"limit_bytes": storage_limit_bytes,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return UploadFileResponse(
|
||||||
|
file_id=workspace_file.id,
|
||||||
|
name=workspace_file.name,
|
||||||
|
path=workspace_file.path,
|
||||||
|
mime_type=workspace_file.mime_type,
|
||||||
|
size_bytes=workspace_file.size_bytes,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/storage/usage",
|
||||||
|
summary="Get workspace storage usage",
|
||||||
|
)
|
||||||
|
async def get_storage_usage(
|
||||||
|
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||||
|
) -> StorageUsageResponse:
|
||||||
|
"""
|
||||||
|
Get storage usage information for the user's workspace.
|
||||||
|
"""
|
||||||
|
config = Config()
|
||||||
|
workspace = await get_or_create_workspace(user_id)
|
||||||
|
|
||||||
|
used_bytes = await get_workspace_total_size(workspace.id)
|
||||||
|
file_count = await count_workspace_files(workspace.id)
|
||||||
|
limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
|
||||||
|
|
||||||
|
return StorageUsageResponse(
|
||||||
|
used_bytes=used_bytes,
|
||||||
|
limit_bytes=limit_bytes,
|
||||||
|
used_percent=round((used_bytes / limit_bytes) * 100, 1) if limit_bytes else 0,
|
||||||
|
file_count=file_count,
|
||||||
|
)
|
||||||
|
|||||||
@@ -0,0 +1,359 @@
|
|||||||
|
"""Tests for workspace file upload and download routes."""
|
||||||
|
|
||||||
|
import io
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import fastapi
|
||||||
|
import fastapi.testclient
|
||||||
|
import pytest
|
||||||
|
import pytest_mock
|
||||||
|
|
||||||
|
from backend.api.features.workspace import routes as workspace_routes
|
||||||
|
from backend.data.workspace import WorkspaceFile
|
||||||
|
|
||||||
|
app = fastapi.FastAPI()
|
||||||
|
app.include_router(workspace_routes.router)
|
||||||
|
|
||||||
|
|
||||||
|
@app.exception_handler(ValueError)
|
||||||
|
async def _value_error_handler(
|
||||||
|
request: fastapi.Request, exc: ValueError
|
||||||
|
) -> fastapi.responses.JSONResponse:
|
||||||
|
"""Mirror the production ValueError → 400 mapping from rest_api.py."""
|
||||||
|
return fastapi.responses.JSONResponse(status_code=400, content={"detail": str(exc)})
|
||||||
|
|
||||||
|
|
||||||
|
client = fastapi.testclient.TestClient(app)
|
||||||
|
|
||||||
|
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||||
|
|
||||||
|
MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})()
|
||||||
|
|
||||||
|
_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
MOCK_FILE = WorkspaceFile(
|
||||||
|
id="file-aaa-bbb",
|
||||||
|
workspace_id="ws-1",
|
||||||
|
created_at=_NOW,
|
||||||
|
updated_at=_NOW,
|
||||||
|
name="hello.txt",
|
||||||
|
path="/session/hello.txt",
|
||||||
|
mime_type="text/plain",
|
||||||
|
size_bytes=13,
|
||||||
|
storage_path="local://hello.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_app_auth(mock_jwt_user):
|
||||||
|
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||||
|
|
||||||
|
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def _upload(
|
||||||
|
filename: str = "hello.txt",
|
||||||
|
content: bytes = b"Hello, world!",
|
||||||
|
content_type: str = "text/plain",
|
||||||
|
):
|
||||||
|
"""Helper to POST a file upload."""
|
||||||
|
return client.post(
|
||||||
|
"/files/upload?session_id=sess-1",
|
||||||
|
files={"file": (filename, io.BytesIO(content), content_type)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---- Happy path ----
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_happy_path(mocker: pytest_mock.MockFixture):
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||||
|
return_value=MOCK_WORKSPACE,
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||||
|
return_value=0,
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.scan_content_safe",
|
||||||
|
return_value=None,
|
||||||
|
)
|
||||||
|
mock_manager = mocker.MagicMock()
|
||||||
|
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||||
|
return_value=mock_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = _upload()
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["file_id"] == "file-aaa-bbb"
|
||||||
|
assert data["name"] == "hello.txt"
|
||||||
|
assert data["size_bytes"] == 13
|
||||||
|
|
||||||
|
|
||||||
|
# ---- Per-file size limit ----
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
|
||||||
|
"""Files larger than max_file_size_mb should be rejected with 413."""
|
||||||
|
cfg = mocker.patch("backend.api.features.workspace.routes.Config")
|
||||||
|
cfg.return_value.max_file_size_mb = 0 # 0 MB → any content is too big
|
||||||
|
cfg.return_value.max_workspace_storage_mb = 500
|
||||||
|
|
||||||
|
response = _upload(content=b"x" * 1024)
|
||||||
|
assert response.status_code == 413
|
||||||
|
|
||||||
|
|
||||||
|
# ---- Storage quota exceeded ----
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||||
|
return_value=MOCK_WORKSPACE,
|
||||||
|
)
|
||||||
|
# Current usage already at limit
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||||
|
return_value=500 * 1024 * 1024,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = _upload()
|
||||||
|
assert response.status_code == 413
|
||||||
|
assert "Storage limit exceeded" in response.text
|
||||||
|
|
||||||
|
|
||||||
|
# ---- Post-write quota race (B2) ----
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
|
||||||
|
"""If a concurrent upload tips the total over the limit after write,
|
||||||
|
the file should be soft-deleted and 413 returned."""
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||||
|
return_value=MOCK_WORKSPACE,
|
||||||
|
)
|
||||||
|
# Pre-write check passes (under limit), but post-write check fails
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||||
|
side_effect=[0, 600 * 1024 * 1024], # first call OK, second over limit
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.scan_content_safe",
|
||||||
|
return_value=None,
|
||||||
|
)
|
||||||
|
mock_manager = mocker.MagicMock()
|
||||||
|
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||||
|
return_value=mock_manager,
|
||||||
|
)
|
||||||
|
mock_delete = mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.soft_delete_workspace_file",
|
||||||
|
return_value=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = _upload()
|
||||||
|
assert response.status_code == 413
|
||||||
|
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-1")
|
||||||
|
|
||||||
|
|
||||||
|
# ---- Any extension accepted (no allowlist) ----
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_any_extension(mocker: pytest_mock.MockFixture):
|
||||||
|
"""Any file extension should be accepted — ClamAV is the security layer."""
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||||
|
return_value=MOCK_WORKSPACE,
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||||
|
return_value=0,
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.scan_content_safe",
|
||||||
|
return_value=None,
|
||||||
|
)
|
||||||
|
mock_manager = mocker.MagicMock()
|
||||||
|
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||||
|
return_value=mock_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = _upload(filename="data.xyz", content=b"arbitrary")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
# ---- Virus scan rejection ----
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
|
||||||
|
"""Files flagged by ClamAV should be rejected and never written to storage."""
|
||||||
|
from backend.api.features.store.exceptions import VirusDetectedError
|
||||||
|
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||||
|
return_value=MOCK_WORKSPACE,
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||||
|
return_value=0,
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.scan_content_safe",
|
||||||
|
side_effect=VirusDetectedError("Eicar-Test-Signature"),
|
||||||
|
)
|
||||||
|
mock_manager = mocker.MagicMock()
|
||||||
|
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||||
|
return_value=mock_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "Virus detected" in response.text
|
||||||
|
mock_manager.write_file.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ---- No file extension ----
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
|
||||||
|
"""Files without an extension should be accepted and stored as-is."""
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||||
|
return_value=MOCK_WORKSPACE,
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||||
|
return_value=0,
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.scan_content_safe",
|
||||||
|
return_value=None,
|
||||||
|
)
|
||||||
|
mock_manager = mocker.MagicMock()
|
||||||
|
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||||
|
return_value=mock_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = _upload(
|
||||||
|
filename="Makefile",
|
||||||
|
content=b"all:\n\techo hello",
|
||||||
|
content_type="application/octet-stream",
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
mock_manager.write_file.assert_called_once()
|
||||||
|
assert mock_manager.write_file.call_args[0][1] == "Makefile"
|
||||||
|
|
||||||
|
|
||||||
|
# ---- Filename sanitization (SF5) ----
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
|
||||||
|
"""Path-traversal filenames should be reduced to their basename."""
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.get_or_create_workspace",
|
||||||
|
return_value=MOCK_WORKSPACE,
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.get_workspace_total_size",
|
||||||
|
return_value=0,
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.scan_content_safe",
|
||||||
|
return_value=None,
|
||||||
|
)
|
||||||
|
mock_manager = mocker.MagicMock()
|
||||||
|
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||||
|
return_value=mock_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filename with traversal
|
||||||
|
_upload(filename="../../etc/passwd.txt")
|
||||||
|
|
||||||
|
# write_file should have been called with just the basename
|
||||||
|
mock_manager.write_file.assert_called_once()
|
||||||
|
call_args = mock_manager.write_file.call_args
|
||||||
|
assert call_args[0][1] == "passwd.txt"
|
||||||
|
|
||||||
|
|
||||||
|
# ---- Download ----
|
||||||
|
|
||||||
|
|
||||||
|
def test_download_file_not_found(mocker: pytest_mock.MockFixture):
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.get_workspace",
|
||||||
|
return_value=MOCK_WORKSPACE,
|
||||||
|
)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.get_workspace_file",
|
||||||
|
return_value=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.get("/files/some-file-id/download")
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
# ---- Delete ----
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_file_success(mocker: pytest_mock.MockFixture):
|
||||||
|
"""Deleting an existing file should return {"deleted": true}."""
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.get_workspace",
|
||||||
|
return_value=MOCK_WORKSPACE,
|
||||||
|
)
|
||||||
|
mock_manager = mocker.MagicMock()
|
||||||
|
mock_manager.delete_file = mocker.AsyncMock(return_value=True)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||||
|
return_value=mock_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.delete("/files/file-aaa-bbb")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"deleted": True}
|
||||||
|
mock_manager.delete_file.assert_called_once_with("file-aaa-bbb")
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
|
||||||
|
"""Deleting a non-existent file should return 404."""
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.get_workspace",
|
||||||
|
return_value=MOCK_WORKSPACE,
|
||||||
|
)
|
||||||
|
mock_manager = mocker.MagicMock()
|
||||||
|
mock_manager.delete_file = mocker.AsyncMock(return_value=False)
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.WorkspaceManager",
|
||||||
|
return_value=mock_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.delete("/files/nonexistent-id")
|
||||||
|
assert response.status_code == 404
|
||||||
|
assert "File not found" in response.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
|
||||||
|
"""Deleting when user has no workspace should return 404."""
|
||||||
|
mocker.patch(
|
||||||
|
"backend.api.features.workspace.routes.get_workspace",
|
||||||
|
return_value=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.delete("/files/file-aaa-bbb")
|
||||||
|
assert response.status_code == 404
|
||||||
|
assert "Workspace not found" in response.text
|
||||||
@@ -94,3 +94,8 @@ class NotificationPayload(pydantic.BaseModel):
|
|||||||
|
|
||||||
class OnboardingNotificationPayload(NotificationPayload):
|
class OnboardingNotificationPayload(NotificationPayload):
|
||||||
step: OnboardingStep | None
|
step: OnboardingStep | None
|
||||||
|
|
||||||
|
|
||||||
|
class CopilotCompletionPayload(NotificationPayload):
|
||||||
|
session_id: str
|
||||||
|
status: Literal["completed", "failed"]
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from prisma.errors import PrismaError
|
|||||||
import backend.api.features.admin.credit_admin_routes
|
import backend.api.features.admin.credit_admin_routes
|
||||||
import backend.api.features.admin.execution_analytics_routes
|
import backend.api.features.admin.execution_analytics_routes
|
||||||
import backend.api.features.admin.store_admin_routes
|
import backend.api.features.admin.store_admin_routes
|
||||||
|
import backend.api.features.admin.user_admin_routes
|
||||||
import backend.api.features.builder
|
import backend.api.features.builder
|
||||||
import backend.api.features.builder.routes
|
import backend.api.features.builder.routes
|
||||||
import backend.api.features.chat.routes as chat_routes
|
import backend.api.features.chat.routes as chat_routes
|
||||||
@@ -26,6 +27,7 @@ import backend.api.features.executions.review.routes
|
|||||||
import backend.api.features.library.db
|
import backend.api.features.library.db
|
||||||
import backend.api.features.library.model
|
import backend.api.features.library.model
|
||||||
import backend.api.features.library.routes
|
import backend.api.features.library.routes
|
||||||
|
import backend.api.features.mcp.routes as mcp_routes
|
||||||
import backend.api.features.oauth
|
import backend.api.features.oauth
|
||||||
import backend.api.features.otto.routes
|
import backend.api.features.otto.routes
|
||||||
import backend.api.features.postmark.postmark
|
import backend.api.features.postmark.postmark
|
||||||
@@ -40,9 +42,9 @@ import backend.data.user
|
|||||||
import backend.integrations.webhooks.utils
|
import backend.integrations.webhooks.utils
|
||||||
import backend.util.service
|
import backend.util.service
|
||||||
import backend.util.settings
|
import backend.util.settings
|
||||||
from backend.api.features.chat.completion_consumer import (
|
from backend.api.features.library.exceptions import (
|
||||||
start_completion_consumer,
|
FolderAlreadyExistsError,
|
||||||
stop_completion_consumer,
|
FolderValidationError,
|
||||||
)
|
)
|
||||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||||
from backend.data.model import Credentials
|
from backend.data.model import Credentials
|
||||||
@@ -54,6 +56,7 @@ from backend.util.exceptions import (
|
|||||||
MissingConfigError,
|
MissingConfigError,
|
||||||
NotAuthorizedError,
|
NotAuthorizedError,
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
|
PreconditionFailed,
|
||||||
)
|
)
|
||||||
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
|
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
|
||||||
from backend.util.service import UnhealthyServiceError
|
from backend.util.service import UnhealthyServiceError
|
||||||
@@ -122,21 +125,9 @@ async def lifespan_context(app: fastapi.FastAPI):
|
|||||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||||
|
|
||||||
# Start chat completion consumer for Redis Streams notifications
|
|
||||||
try:
|
|
||||||
await start_completion_consumer()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Could not start chat completion consumer: {e}")
|
|
||||||
|
|
||||||
with launch_darkly_context():
|
with launch_darkly_context():
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Stop chat completion consumer
|
|
||||||
try:
|
|
||||||
await stop_completion_consumer()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error stopping chat completion consumer: {e}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await shutdown_cloud_storage_handler()
|
await shutdown_cloud_storage_handler()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -276,12 +267,17 @@ async def validation_error_handler(
|
|||||||
|
|
||||||
|
|
||||||
app.add_exception_handler(PrismaError, handle_internal_http_error(500))
|
app.add_exception_handler(PrismaError, handle_internal_http_error(500))
|
||||||
|
app.add_exception_handler(
|
||||||
|
FolderAlreadyExistsError, handle_internal_http_error(409, False)
|
||||||
|
)
|
||||||
|
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400, False))
|
||||||
app.add_exception_handler(NotFoundError, handle_internal_http_error(404, False))
|
app.add_exception_handler(NotFoundError, handle_internal_http_error(404, False))
|
||||||
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403, False))
|
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403, False))
|
||||||
app.add_exception_handler(RequestValidationError, validation_error_handler)
|
app.add_exception_handler(RequestValidationError, validation_error_handler)
|
||||||
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
|
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
|
||||||
app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
|
app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
|
||||||
app.add_exception_handler(ValueError, handle_internal_http_error(400))
|
app.add_exception_handler(ValueError, handle_internal_http_error(400))
|
||||||
|
app.add_exception_handler(PreconditionFailed, handle_internal_http_error(428))
|
||||||
app.add_exception_handler(Exception, handle_internal_http_error(500))
|
app.add_exception_handler(Exception, handle_internal_http_error(500))
|
||||||
|
|
||||||
app.include_router(backend.api.features.v1.v1_router, tags=["v1"], prefix="/api")
|
app.include_router(backend.api.features.v1.v1_router, tags=["v1"], prefix="/api")
|
||||||
@@ -316,6 +312,11 @@ app.include_router(
|
|||||||
tags=["v2", "admin"],
|
tags=["v2", "admin"],
|
||||||
prefix="/api/executions",
|
prefix="/api/executions",
|
||||||
)
|
)
|
||||||
|
app.include_router(
|
||||||
|
backend.api.features.admin.user_admin_routes.router,
|
||||||
|
tags=["v2", "admin"],
|
||||||
|
prefix="/api/users",
|
||||||
|
)
|
||||||
app.include_router(
|
app.include_router(
|
||||||
backend.api.features.executions.review.routes.router,
|
backend.api.features.executions.review.routes.router,
|
||||||
tags=["v2", "executions", "review"],
|
tags=["v2", "executions", "review"],
|
||||||
@@ -343,6 +344,11 @@ app.include_router(
|
|||||||
tags=["workspace"],
|
tags=["workspace"],
|
||||||
prefix="/api/workspace",
|
prefix="/api/workspace",
|
||||||
)
|
)
|
||||||
|
app.include_router(
|
||||||
|
mcp_routes.router,
|
||||||
|
tags=["v2", "mcp"],
|
||||||
|
prefix="/api/mcp",
|
||||||
|
)
|
||||||
app.include_router(
|
app.include_router(
|
||||||
backend.api.features.oauth.router,
|
backend.api.features.oauth.router,
|
||||||
tags=["oauth"],
|
tags=["oauth"],
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user