Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-03-17 03:00:27 -04:00)
---
name: pr-address
description: Address PR review comments and loop until CI green and all comments resolved. TRIGGER when user asks to address comments, fix PR feedback, respond to reviewers, or babysit/monitor a PR.
user-invocable: true
args: "[PR number or URL] — if omitted, finds PR for current branch."
metadata:
  author: autogpt-team
  version: "1.0.0"
---

# PR Address

## Find the PR

```bash
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
gh pr view {N}
```

## Fetch comments (all sources)

```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews   # top-level reviews
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments  # inline review comments
gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments # PR conversation comments
```

**Bots to watch for:**

- `autogpt-reviewer` — posts "Blockers", "Should Fix", "Nice to Have". Address ALL of them.
- `sentry[bot]` — bug predictions. Fix real bugs, explain false positives.
- `coderabbitai[bot]` — automated review. Address actionable items.
## For each unaddressed comment

Address comments **one at a time**: fix → commit → push → inline reply → next.

1. Read the referenced code and make the fix (or reply explaining why it's not needed)
2. Commit and push the fix
3. Reply **inline** (not as a new top-level comment), referencing the fixing commit — this is what resolves the conversation for bot reviewers (coderabbitai, sentry):

| Comment type | How to reply |
|---|---|
| Inline review (`pulls/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments/{ID}/replies -f body="Fixed in <commit-sha>: <description>"` |
| Conversation (`issues/{N}/comments`) | `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments -f body="Fixed in <commit-sha>: <description>"` |
## Format and commit

After fixing, format the changed code:

- **Backend** (from `autogpt_platform/backend/`): `poetry run format`
- **Frontend** (from `autogpt_platform/frontend/`): `pnpm format && pnpm lint && pnpm types`

If API routes changed, regenerate the frontend client:

```bash
cd autogpt_platform/backend && poetry run rest &
REST_PID=$!
trap "kill $REST_PID 2>/dev/null" EXIT
WAIT=0
until curl -sf http://localhost:8006/health > /dev/null 2>&1; do
  sleep 1
  WAIT=$((WAIT + 1))
  [ $WAIT -ge 60 ] && echo "Timed out" && exit 1
done
cd ../frontend && pnpm generate:api:force
kill $REST_PID 2>/dev/null; trap - EXIT
```

Never manually edit files in `src/app/api/__generated__/`.

Then commit and **push immediately** — never batch commits without pushing.

For backend commits in worktrees, commit with `poetry run git commit` so the pre-commit hooks can run.
## The loop

```text
address comments → format → commit → push
→ re-check comments → fix new ones → push
→ wait for CI → re-check comments after CI settles
→ repeat until: all comments addressed AND CI green AND no new comments arriving
```

While CI runs, stay productive: run local tests, address remaining comments.

**The loop ends when:** CI is fully green + all comments addressed + no new comments since CI settled.
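The "CI green" check in the loop above can be sketched as a small filter over the output of `gh pr checks`. This is a hypothetical helper, not part of the skill: `checks_green` is an invented name, and the assumption that each check prints one line containing a status word such as `pass`, `fail`, or `pending` is an approximation of `gh`'s tab-separated output.

```shell
# Hypothetical helper: succeed only when no check line reports a
# failing or pending status. Reads check lines on stdin.
checks_green() {
  ! grep -qEi '(fail|pending)'
}

# Usage sketch (needs an authenticated gh, so not run here):
#   gh pr checks {N} | checks_green && echo "CI green"
```

A check whose *name* contains "fail" would trip this filter, so treat it as a starting point rather than a robust parser.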
---
name: pr-review
description: Review a PR for correctness, security, code quality, and testing issues. TRIGGER when user asks to review a PR, check PR quality, or give feedback on a PR.
user-invocable: true
args: "[PR number or URL] — if omitted, finds PR for current branch."
metadata:
  author: autogpt-team
  version: "1.0.0"
---

# PR Review

## Find the PR

```bash
gh pr list --head $(git branch --show-current) --repo Significant-Gravitas/AutoGPT
gh pr view {N}
```

## Read the diff

```bash
gh pr diff {N}
```

## Fetch existing review comments

Before posting anything, fetch existing inline comments to avoid duplicates:

```bash
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews
```
## What to check

**Correctness:** logic errors, off-by-one, missing edge cases, race conditions (TOCTOU in file access, credit charging), error-handling gaps, async correctness (missing `await`, unclosed resources).

**Security:** input validation at boundaries, no injection (command, XSS, SQL), secrets not logged, file paths sanitized (`os.path.basename()` in error messages).

**Code quality:** apply the rules from the backend/frontend CLAUDE.md files.

**Architecture:** DRY, single responsibility, modular functions. `Security()` vs `Depends()` for FastAPI auth. `data:` for SSE events, `: comment` for heartbeats. `transaction=True` for Redis pipelines.

**Testing:** edge cases covered, colocated `*_test.py` (backend) / `__tests__/` (frontend), mocks target where the symbol is **used**, not where it is defined, `AsyncMock` for async code.

## Output format

Every comment **must** be prefixed with `🤖` and a criticality badge:

| Tier | Badge | Meaning |
|---|---|---|
| Blocker | `🔴 **Blocker**` | Must fix before merge |
| Should Fix | `🟠 **Should Fix**` | Important improvement |
| Nice to Have | `🟡 **Nice to Have**` | Minor suggestion |
| Nit | `🔵 **Nit**` | Style / wording |

Example: `🤖 🔴 **Blocker**: Missing error handling for X — suggest wrapping in try/except.`
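The badge prefix format can be applied mechanically. A minimal sketch, with an invented helper name (`review_comment` is not part of the skill):

```shell
# Hypothetical helper: prefix a finding with the robot marker and a badge.
# usage: review_comment "<badge>" "<finding text>"
review_comment() {
  printf '🤖 %s: %s' "$1" "$2"
}
```

For example, `review_comment "🔴 **Blocker**" "Missing error handling for X"` produces a body string in the format the table requires.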
## Post inline comments

For each finding, post an inline comment on the PR (do not just write a local report):

```bash
# Get the latest commit SHA for the PR
COMMIT_SHA=$(gh api repos/Significant-Gravitas/AutoGPT/pulls/{N} --jq '.head.sha')

# Post an inline comment on a specific file/line
gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments \
  -f body="🤖 🔴 **Blocker**: <description>" \
  -f commit_id="$COMMIT_SHA" \
  -f path="<file path>" \
  -F line=<line number>
```
---
name: worktree
description: Set up a new git worktree for parallel development. Creates the worktree, copies .env files, installs dependencies, and generates the Prisma client. TRIGGER when user asks to set up a worktree, work on a branch in isolation, or needs a separate environment for a branch or PR.
user-invocable: true
args: "[name] — optional worktree name (e.g., 'AutoGPT7'). If omitted, uses the next available AutoGPT<N>."
metadata:
  author: autogpt-team
  version: "3.0.0"
---

# Worktree Setup

## Create the worktree

Derive paths from the git toplevel. If a name is provided as an argument, use it; otherwise, check `git worktree list` and pick the next `AutoGPT<N>`.

```bash
ROOT=$(git rev-parse --show-toplevel)
PARENT=$(dirname "$ROOT")

# From an existing branch
git worktree add "$PARENT/<NAME>" <branch-name>

# From a new branch off dev
git worktree add -b <new-branch> "$PARENT/<NAME>" dev
```
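Picking "the next available `AutoGPT<N>`" can be sketched as a small helper. This is hypothetical (the function name and the stdin interface are invented): it reads existing worktree paths on stdin, e.g. from `git worktree list --porcelain | awk '/^worktree /{print $2}'`, and prints the next free name.

```shell
# Hypothetical helper: given existing worktree paths on stdin,
# print AutoGPT<N> with N one past the highest number in use.
next_worktree_name() {
  max=0
  while IFS= read -r path; do
    # extract the numeric suffix of names matching AutoGPT<N>
    n=$(basename "$path" | sed -n 's/^AutoGPT\([0-9][0-9]*\)$/\1/p')
    if [ -n "$n" ] && [ "$n" -gt "$max" ]; then max=$n; fi
  done
  echo "AutoGPT$((max + 1))"
}
```

Paths that don't match the `AutoGPT<N>` pattern (including the bare `AutoGPT` root checkout) are simply ignored.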
## Copy environment files

Copy `.env` from the root worktree, falling back to `.env.default` where `.env` doesn't exist.

```bash
ROOT=$(git rev-parse --show-toplevel)
TARGET="$(dirname "$ROOT")/<NAME>"

for envpath in autogpt_platform/backend autogpt_platform/frontend autogpt_platform; do
  if [ -f "$ROOT/$envpath/.env" ]; then
    cp "$ROOT/$envpath/.env" "$TARGET/$envpath/.env"
  elif [ -f "$ROOT/$envpath/.env.default" ]; then
    cp "$ROOT/$envpath/.env.default" "$TARGET/$envpath/.env"
  fi
done
```
## Install dependencies

```bash
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
cd "$TARGET/autogpt_platform/autogpt_libs" && poetry install
cd "$TARGET/autogpt_platform/backend" && poetry install && poetry run prisma generate
cd "$TARGET/autogpt_platform/frontend" && pnpm install
```

Replace `<NAME>` with the actual worktree name (e.g., `AutoGPT7`).

## Running the app (optional)

The backend uses ports 8001, 8002, 8003, 8005, 8006, 8007, and 8008. Free them first if needed:

```bash
TARGET="$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
for port in 8001 8002 8003 8005 8006 8007 8008; do
  lsof -ti :$port | xargs kill -9 2>/dev/null || true
done
cd "$TARGET/autogpt_platform/backend" && poetry run app
```
## CoPilot testing

SDK mode spawns a Claude subprocess, so it won't work inside Claude Code. Set `CHAT_USE_CLAUDE_AGENT_SDK=false` in `backend/.env` to use baseline mode instead.

## Cleanup

```bash
# Replace <NAME> with the actual worktree name (e.g., AutoGPT7)
git worktree remove "$(dirname "$(git rev-parse --show-toplevel)")/<NAME>"
```

## Alternative: Branchlet (optional)

If [branchlet](https://www.npmjs.com/package/branchlet) is installed:

```bash
branchlet create -n <name> -s <source-branch> -b <new-branch>
```
**.github/workflows/platform-frontend-ci.yml** (2 changed lines, vendored)

```diff
@@ -149,7 +149,7 @@ jobs:
           driver-opts: network=host

       - name: Set up Platform - Expose GHA cache to docker buildx CLI
-        uses: crazy-max/ghaction-github-runtime@v4
+        uses: crazy-max/ghaction-github-runtime@v3

       - name: Set up Platform - Build Docker images (with cache)
         working-directory: autogpt_platform
```
**.gitignore** (4 changed lines, vendored)

```diff
@@ -180,6 +180,4 @@ autogpt_platform/backend/settings.py
 .claude/settings.local.json
 CLAUDE.local.md
 /autogpt_platform/backend/logs
 .next
-# Implementation plans (generated by AI agents)
-plans/
```
**.pre-commit-config.yaml**

```diff
@@ -1,10 +1,3 @@
-default_install_hook_types:
-  - pre-commit
-  - pre-push
-  - post-checkout
-
-default_stages: [pre-commit]
-
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.4.0
@@ -24,7 +17,6 @@ repos:
         name: Detect secrets
         description: Detects high entropy strings that are likely to be passwords.
         files: ^autogpt_platform/
-        exclude: pnpm-lock\.yaml$
         stages: [pre-push]

   - repo: local
@@ -34,106 +26,49 @@ repos:
       - id: poetry-install
         name: Check & Install dependencies - AutoGPT Platform - Backend
         alias: poetry-install-platform-backend
         # include autogpt_libs source (since it's a path dependency)
-        entry: >
-          bash -c '
-            if [ -n "$PRE_COMMIT_FROM_REF" ]; then
-              git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
-            else
-              git diff --cached --name-only
-            fi | grep -qE "^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$" || exit 0;
-            poetry -C autogpt_platform/backend install
-          '
-        always_run: true
+        entry: poetry -C autogpt_platform/backend install
+        files: ^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$
+        types: [file]
         language: system
         pass_filenames: false
-        stages: [pre-commit, post-checkout]

       - id: poetry-install
         name: Check & Install dependencies - AutoGPT Platform - Libs
         alias: poetry-install-platform-libs
-        entry: >
-          bash -c '
-            if [ -n "$PRE_COMMIT_FROM_REF" ]; then
-              git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
-            else
-              git diff --cached --name-only
-            fi | grep -qE "^autogpt_platform/autogpt_libs/poetry\.lock$" || exit 0;
-            poetry -C autogpt_platform/autogpt_libs install
-          '
-        always_run: true
+        entry: poetry -C autogpt_platform/autogpt_libs install
+        files: ^autogpt_platform/autogpt_libs/poetry\.lock$
+        types: [file]
         language: system
         pass_filenames: false
-        stages: [pre-commit, post-checkout]
-
-      - id: pnpm-install
-        name: Check & Install dependencies - AutoGPT Platform - Frontend
-        alias: pnpm-install-platform-frontend
-        entry: >
-          bash -c '
-            if [ -n "$PRE_COMMIT_FROM_REF" ]; then
-              git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
-            else
-              git diff --cached --name-only
-            fi | grep -qE "^autogpt_platform/frontend/pnpm-lock\.yaml$" || exit 0;
-            pnpm --prefix autogpt_platform/frontend install
-          '
-        always_run: true
-        language: system
-        pass_filenames: false
-        stages: [pre-commit, post-checkout]

       - id: poetry-install
         name: Check & Install dependencies - Classic - AutoGPT
         alias: poetry-install-classic-autogpt
         # include forge source (since it's a path dependency)
-        entry: >
-          bash -c '
-            if [ -n "$PRE_COMMIT_FROM_REF" ]; then
-              git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
-            else
-              git diff --cached --name-only
-            fi | grep -qE "^classic/(original_autogpt|forge)/poetry\.lock$" || exit 0;
-            poetry -C classic/original_autogpt install
-          '
-        always_run: true
+        entry: poetry -C classic/original_autogpt install
+        files: ^classic/(original_autogpt|forge)/poetry\.lock$
+        types: [file]
         language: system
         pass_filenames: false
-        stages: [pre-commit, post-checkout]

       - id: poetry-install
         name: Check & Install dependencies - Classic - Forge
         alias: poetry-install-classic-forge
-        entry: >
-          bash -c '
-            if [ -n "$PRE_COMMIT_FROM_REF" ]; then
-              git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
-            else
-              git diff --cached --name-only
-            fi | grep -qE "^classic/forge/poetry\.lock$" || exit 0;
-            poetry -C classic/forge install
-          '
-        always_run: true
+        entry: poetry -C classic/forge install
+        files: ^classic/forge/poetry\.lock$
+        types: [file]
         language: system
         pass_filenames: false
-        stages: [pre-commit, post-checkout]

       - id: poetry-install
         name: Check & Install dependencies - Classic - Benchmark
         alias: poetry-install-classic-benchmark
-        entry: >
-          bash -c '
-            if [ -n "$PRE_COMMIT_FROM_REF" ]; then
-              git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
-            else
-              git diff --cached --name-only
-            fi | grep -qE "^classic/benchmark/poetry\.lock$" || exit 0;
-            poetry -C classic/benchmark install
-          '
-        always_run: true
+        entry: poetry -C classic/benchmark install
+        files: ^classic/benchmark/poetry\.lock$
+        types: [file]
         language: system
         pass_filenames: false
-        stages: [pre-commit, post-checkout]

   - repo: local
     # For proper type checking, Prisma client must be up-to-date.
@@ -141,54 +76,12 @@ repos:
       - id: prisma-generate
         name: Prisma Generate - AutoGPT Platform - Backend
         alias: prisma-generate-platform-backend
         # include everything that triggers poetry install + the prisma schema
-        entry: >
-          bash -c '
-            if [ -n "$PRE_COMMIT_FROM_REF" ]; then
-              git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
-            else
-              git diff --cached --name-only
-            fi | grep -qE "^autogpt_platform/((backend|autogpt_libs)/poetry\.lock|backend/schema\.prisma)$" || exit 0;
-            cd autogpt_platform/backend
-            && poetry run prisma generate
-            && poetry run gen-prisma-stub
-          '
-        always_run: true
+        entry: bash -c 'cd autogpt_platform/backend && poetry run prisma generate'
+        files: ^autogpt_platform/((backend|autogpt_libs)/poetry\.lock|backend/schema.prisma)$
+        types: [file]
         language: system
         pass_filenames: false
-        stages: [pre-commit, post-checkout]
-
-      - id: export-api-schema
-        name: Export API schema - AutoGPT Platform - Backend -> Frontend
-        alias: export-api-schema-platform
-        entry: >
-          bash -c '
-            cd autogpt_platform/backend
-            && poetry run export-api-schema --output ../frontend/src/app/api/openapi.json
-            && cd ../frontend
-            && pnpm prettier --write ./src/app/api/openapi.json
-          '
-        files: ^autogpt_platform/backend/
-        language: system
-        pass_filenames: false
-
-      - id: generate-api-client
-        name: Generate API client - AutoGPT Platform - Frontend
-        alias: generate-api-client-platform-frontend
-        entry: >
-          bash -c '
-            SCHEMA=autogpt_platform/frontend/src/app/api/openapi.json;
-            if [ -n "$PRE_COMMIT_FROM_REF" ]; then
-              git diff --quiet "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF" -- "$SCHEMA" && exit 0
-            else
-              git diff --quiet HEAD -- "$SCHEMA" && exit 0
-            fi;
-            cd autogpt_platform/frontend && pnpm generate:api
-          '
-        always_run: true
-        language: system
-        pass_filenames: false
-        stages: [pre-commit, post-checkout]

   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.7.2
```
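The removed hooks all shared one guard pattern: list the changed files (staged files on pre-commit, the `PRE_COMMIT_FROM_REF..PRE_COMMIT_TO_REF` range on pre-push), grep for a path pattern, and skip the install otherwise. As a standalone sketch with an invented function name, taking the file list on stdin instead of calling `git diff` so it can run anywhere:

```shell
# Hypothetical helper mirroring the removed hooks' guard:
# run a command only when the changed-file list matches a pattern.
# usage: <changed files, one per line, on stdin> | run_if_changed '<ERE>' <command...>
run_if_changed() {
  pattern=$1; shift
  grep -qE "$pattern" || return 0  # no match: skip silently, still succeed
  "$@"
}
```

In the real hooks the stdin came from `git diff --cached --name-only` or `git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"`; the new config replaces this whole dance with pre-commit's built-in `files:` filter.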
**autogpt_platform/.gitignore** (3 changed lines, vendored)

```diff
@@ -1,3 +1,2 @@
 *.ignore.*
 *.ign.*
-.application.logs
```
```diff
@@ -60,12 +60,9 @@ AutoGPT Platform is a monorepo containing:

 ### Reviewing/Revising Pull Requests

-Use `/pr-review` to review a PR or `/pr-address` to address comments.
-
-When fetching comments manually:
-- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/reviews` — top-level reviews
-- `gh api repos/Significant-Gravitas/AutoGPT/pulls/{N}/comments` — inline review comments
-- `gh api repos/Significant-Gravitas/AutoGPT/issues/{N}/comments` — PR conversation comments
+- When the user runs /pr-comments or tries to fetch them, also run gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews to get the reviews
+- Use gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews/[review_id]/comments to get the review contents
+- Use gh api /repos/Significant-Gravitas/AutoGPT/issues/9924/comments to get the pr specific comments

 ### Conventional Commits
```
**Deleted view definition** (analytics.auth_activities):

```sql
-- =============================================================
-- View: analytics.auth_activities
-- Looker source alias: ds49 | Charts: 1
-- =============================================================
-- DESCRIPTION
--   Tracks authentication events (login, logout, SSO, password
--   reset, etc.) from Supabase's internal audit log.
--   Useful for monitoring sign-in patterns and detecting anomalies.
--
-- SOURCE TABLES
--   auth.audit_log_entries — Supabase internal auth event log
--
-- OUTPUT COLUMNS
--   created_at     TIMESTAMPTZ  When the auth event occurred
--   actor_id       TEXT         User ID who triggered the event
--   actor_via_sso  TEXT         Whether the action was via SSO ('true'/'false')
--   action         TEXT         Event type (e.g. 'login', 'logout', 'token_refreshed')
--
-- WINDOW
--   Rolling 90 days from current date
--
-- EXAMPLE QUERIES
--   -- Daily login counts
--   SELECT DATE_TRUNC('day', created_at) AS day, COUNT(*) AS logins
--   FROM analytics.auth_activities
--   WHERE action = 'login'
--   GROUP BY 1 ORDER BY 1;
--
--   -- SSO vs password login breakdown
--   SELECT actor_via_sso, COUNT(*) FROM analytics.auth_activities
--   WHERE action = 'login' GROUP BY 1;
-- =============================================================

SELECT
    created_at,
    payload->>'actor_id'      AS actor_id,
    payload->>'actor_via_sso' AS actor_via_sso,
    payload->>'action'        AS action
FROM auth.audit_log_entries
WHERE created_at >= NOW() - INTERVAL '90 days'
```
**Deleted view definition** (analytics.graph_execution):

```sql
-- =============================================================
-- View: analytics.graph_execution
-- Looker source alias: ds16 | Charts: 21
-- =============================================================
-- DESCRIPTION
--   One row per agent graph execution (last 90 days).
--   Unpacks the JSONB stats column into individual numeric columns
--   and normalises the executionStatus — runs that failed due to
--   insufficient credits are reclassified as 'NO_CREDITS' for
--   easier filtering. Error messages are scrubbed of IDs and URLs
--   to allow safe grouping.
--
-- SOURCE TABLES
--   platform.AgentGraphExecution — Execution records
--   platform.AgentGraph          — Agent graph metadata (for name)
--   platform.LibraryAgent        — To flag possibly-AI (safe-mode) agents
--
-- OUTPUT COLUMNS
--   id                  TEXT        Execution UUID
--   agentGraphId        TEXT        Agent graph UUID
--   agentGraphVersion   INT         Graph version number
--   executionStatus     TEXT        COMPLETED | FAILED | NO_CREDITS | RUNNING | QUEUED | TERMINATED
--   createdAt           TIMESTAMPTZ When the execution was queued
--   updatedAt           TIMESTAMPTZ Last status update time
--   userId              TEXT        Owner user UUID
--   agentGraphName      TEXT        Human-readable agent name
--   cputime             DECIMAL     Total CPU seconds consumed
--   walltime            DECIMAL     Total wall-clock seconds
--   node_count          DECIMAL     Number of nodes in the graph
--   nodes_cputime       DECIMAL     CPU time across all nodes
--   nodes_walltime      DECIMAL     Wall time across all nodes
--   execution_cost      DECIMAL     Credit cost of this execution
--   correctness_score   FLOAT       AI correctness score (if available)
--   possibly_ai         BOOLEAN     True if agent has sensitive_action_safe_mode enabled
--   groupedErrorMessage TEXT        Scrubbed error string (IDs/URLs replaced with wildcards)
--
-- WINDOW
--   Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
--   -- Daily execution counts by status
--   SELECT DATE_TRUNC('day', "createdAt") AS day, "executionStatus", COUNT(*)
--   FROM analytics.graph_execution
--   GROUP BY 1, 2 ORDER BY 1;
--
--   -- Average cost per execution by agent
--   SELECT "agentGraphName", AVG("execution_cost") AS avg_cost, COUNT(*) AS runs
--   FROM analytics.graph_execution
--   WHERE "executionStatus" = 'COMPLETED'
--   GROUP BY 1 ORDER BY avg_cost DESC;
--
--   -- Top error messages
--   SELECT "groupedErrorMessage", COUNT(*) AS occurrences
--   FROM analytics.graph_execution
--   WHERE "executionStatus" = 'FAILED'
--   GROUP BY 1 ORDER BY 2 DESC LIMIT 20;
-- =============================================================

SELECT
    ge."id"                AS id,
    ge."agentGraphId"      AS agentGraphId,
    ge."agentGraphVersion" AS agentGraphVersion,
    CASE
        WHEN jsonb_exists(ge."stats"::jsonb, 'error')
            AND (
                (ge."stats"::jsonb->>'error') ILIKE '%insufficient balance%'
                OR (ge."stats"::jsonb->>'error') ILIKE '%you have no credits left%'
            )
        THEN 'NO_CREDITS'
        ELSE CAST(ge."executionStatus" AS TEXT)
    END AS executionStatus,
    ge."createdAt" AS createdAt,
    ge."updatedAt" AS updatedAt,
    ge."userId"    AS userId,
    g."name"       AS agentGraphName,
    (ge."stats"::jsonb->>'cputime')::decimal         AS cputime,
    (ge."stats"::jsonb->>'walltime')::decimal        AS walltime,
    (ge."stats"::jsonb->>'node_count')::decimal      AS node_count,
    (ge."stats"::jsonb->>'nodes_cputime')::decimal   AS nodes_cputime,
    (ge."stats"::jsonb->>'nodes_walltime')::decimal  AS nodes_walltime,
    (ge."stats"::jsonb->>'cost')::decimal            AS execution_cost,
    (ge."stats"::jsonb->>'correctness_score')::float AS correctness_score,
    COALESCE(la.possibly_ai, FALSE) AS possibly_ai,
    REGEXP_REPLACE(
        REGEXP_REPLACE(
            TRIM(BOTH '"' FROM ge."stats"::jsonb->>'error'),
            '(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
            '\1\2/...', 'gi'
        ),
        '[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
    ) AS groupedErrorMessage
FROM platform."AgentGraphExecution" ge
LEFT JOIN platform."AgentGraph" g
    ON ge."agentGraphId" = g."id"
    AND ge."agentGraphVersion" = g."version"
LEFT JOIN (
    SELECT DISTINCT ON ("userId", "agentGraphId")
        "userId", "agentGraphId",
        ("settings"::jsonb->>'sensitive_action_safe_mode')::boolean AS possibly_ai
    FROM platform."LibraryAgent"
    WHERE "isDeleted" = FALSE
      AND "isArchived" = FALSE
    ORDER BY "userId", "agentGraphId", "agentGraphVersion" DESC
) la ON la."userId" = ge."userId" AND la."agentGraphId" = ge."agentGraphId"
WHERE ge."createdAt" > CURRENT_DATE - INTERVAL '90 days'
```
@@ -1,101 +0,0 @@
-- =============================================================
-- View: analytics.node_block_execution
-- Looker source alias: ds14 | Charts: 11
-- =============================================================
-- DESCRIPTION
-- One row per node (block) execution (last 90 days).
-- Unpacks stats JSONB and joins to identify which block type
-- was run. For failed nodes, joins the error output and
-- scrubs it for safe grouping.
--
-- SOURCE TABLES
-- platform.AgentNodeExecution — Node execution records
-- platform.AgentNode — Node → block mapping
-- platform.AgentBlock — Block name/ID
-- platform.AgentNodeExecutionInputOutput — Error output values
--
-- OUTPUT COLUMNS
-- id TEXT Node execution UUID
-- agentGraphExecutionId TEXT Parent graph execution UUID
-- agentNodeId TEXT Node UUID within the graph
-- executionStatus TEXT COMPLETED | FAILED | QUEUED | RUNNING | TERMINATED
-- addedTime TIMESTAMPTZ When the node was queued
-- queuedTime TIMESTAMPTZ When it entered the queue
-- startedTime TIMESTAMPTZ When execution started
-- endedTime TIMESTAMPTZ When execution finished
-- inputSize BIGINT Input payload size in bytes
-- outputSize BIGINT Output payload size in bytes
-- walltime NUMERIC Wall-clock seconds for this node
-- cputime NUMERIC CPU seconds for this node
-- llmRetryCount INT Number of LLM retries
-- llmCallCount INT Number of LLM API calls made
-- inputTokenCount BIGINT LLM input tokens consumed
-- outputTokenCount BIGINT LLM output tokens produced
-- blockName TEXT Human-readable block name (e.g. 'OpenAIBlock')
-- blockId TEXT Block UUID
-- groupedErrorMessage TEXT Scrubbed error (IDs/URLs wildcarded)
-- errorMessage TEXT Raw error output (only set when FAILED)
--
-- WINDOW
-- Rolling 90 days (addedTime > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Most-used blocks by execution count
-- SELECT "blockName", COUNT(*) AS executions,
-- COUNT(*) FILTER (WHERE "executionStatus"='FAILED') AS failures
-- FROM analytics.node_block_execution
-- GROUP BY 1 ORDER BY executions DESC LIMIT 20;
--
-- -- Average LLM token usage per block
-- SELECT "blockName",
-- AVG("inputTokenCount") AS avg_input_tokens,
-- AVG("outputTokenCount") AS avg_output_tokens
-- FROM analytics.node_block_execution
-- WHERE "llmCallCount" > 0
-- GROUP BY 1 ORDER BY avg_input_tokens DESC;
--
-- -- Top failure reasons
-- SELECT "blockName", "groupedErrorMessage", COUNT(*) AS count
-- FROM analytics.node_block_execution
-- WHERE "executionStatus" = 'FAILED'
-- GROUP BY 1, 2 ORDER BY count DESC LIMIT 20;
-- =============================================================

SELECT
  ne."id" AS id,
  ne."agentGraphExecutionId" AS agentGraphExecutionId,
  ne."agentNodeId" AS agentNodeId,
  CAST(ne."executionStatus" AS TEXT) AS executionStatus,
  ne."addedTime" AS addedTime,
  ne."queuedTime" AS queuedTime,
  ne."startedTime" AS startedTime,
  ne."endedTime" AS endedTime,
  (ne."stats"::jsonb->>'input_size')::bigint AS inputSize,
  (ne."stats"::jsonb->>'output_size')::bigint AS outputSize,
  (ne."stats"::jsonb->>'walltime')::numeric AS walltime,
  (ne."stats"::jsonb->>'cputime')::numeric AS cputime,
  (ne."stats"::jsonb->>'llm_retry_count')::int AS llmRetryCount,
  (ne."stats"::jsonb->>'llm_call_count')::int AS llmCallCount,
  (ne."stats"::jsonb->>'input_token_count')::bigint AS inputTokenCount,
  (ne."stats"::jsonb->>'output_token_count')::bigint AS outputTokenCount,
  b."name" AS blockName,
  b."id" AS blockId,
  REGEXP_REPLACE(
    REGEXP_REPLACE(
      TRIM(BOTH '"' FROM eio."data"::text),
      '(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
      '\1\2/...', 'gi'
    ),
    '[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
  ) AS groupedErrorMessage,
  eio."data" AS errorMessage
FROM platform."AgentNodeExecution" ne
LEFT JOIN platform."AgentNode" nd
  ON ne."agentNodeId" = nd."id"
LEFT JOIN platform."AgentBlock" b
  ON nd."agentBlockId" = b."id"
LEFT JOIN platform."AgentNodeExecutionInputOutput" eio
  ON eio."referencedByOutputExecId" = ne."id"
  AND eio."name" = 'error'
  AND ne."executionStatus" = 'FAILED'
WHERE ne."addedTime" > CURRENT_DATE - INTERVAL '90 days'
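The two-pass scrub behind groupedErrorMessage can be sanity-checked in isolation. This is a minimal sketch with a made-up error string (not real view output): the first pass keeps the URL scheme and host but wildcards port and path, the second replaces any token containing a digit with `*`.

```sql
-- Sketch of the two-pass scrub on a hypothetical error message.
SELECT REGEXP_REPLACE(
         REGEXP_REPLACE(
           'Request to https://api.example.com:443/v1/runs/abc123 failed with status 429',
           '(https?://)([A-Za-z0-9.-]+)(:[0-9]+)?(/[^\s]*)?',
           '\1\2/...', 'gi'
         ),
         '[a-zA-Z0-9_:-]*\d[a-zA-Z0-9_:-]*', '*', 'g'
       ) AS grouped;
-- → 'Request to https://api.example.com/... failed with status *'
```

Because run IDs, ports, and status codes all collapse to `*`, errors that differ only in those specifics group into one row.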
@@ -1,97 +0,0 @@
-- =============================================================
-- View: analytics.retention_agent
-- Looker source alias: ds35 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention broken down per individual agent.
-- Cohort = week of a user's first use of THAT specific agent.
-- Tells you which agents keep users coming back vs. one-shot
-- use. Only includes cohorts from the last 180 days.
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records (user × agent × time)
-- platform.AgentGraph — Agent names
--
-- OUTPUT COLUMNS
-- agent_id TEXT Agent graph UUID
-- agent_label TEXT 'AgentName [first8chars]'
-- agent_label_n TEXT 'AgentName [first8chars] (n=total_users)'
-- cohort_week_start DATE Week users first ran this agent
-- cohort_label TEXT ISO week label
-- cohort_label_n TEXT ISO week label with cohort size
-- user_lifetime_week INT Weeks since first use of this agent
-- cohort_users BIGINT Users in this cohort for this agent
-- active_users BIGINT Users who ran the agent again in week k
-- retention_rate FLOAT active_users / cohort_users
-- cohort_users_w0 BIGINT cohort_users only at week 0 (safe to SUM)
-- agent_total_users BIGINT Total users across all cohorts for this agent
--
-- EXAMPLE QUERIES
-- -- Best-retained agents at week 2
-- SELECT agent_label, AVG(retention_rate) AS w2_retention
-- FROM analytics.retention_agent
-- WHERE user_lifetime_week = 2 AND cohort_users >= 10
-- GROUP BY 1 ORDER BY w2_retention DESC LIMIT 10;
--
-- -- Agents with most unique users
-- SELECT DISTINCT agent_label, agent_total_users
-- FROM analytics.retention_agent
-- ORDER BY agent_total_users DESC LIMIT 20;
-- =============================================================

WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
events AS (
  SELECT e."userId"::text AS user_id, e."agentGraphId" AS agent_id,
         e."createdAt"::timestamptz AS created_at,
         DATE_TRUNC('week', e."createdAt")::date AS week_start
  FROM platform."AgentGraphExecution" e
),
first_use AS (
  SELECT user_id, agent_id, MIN(created_at) AS first_use_at,
         DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
  FROM events GROUP BY 1,2
  HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_weeks AS (SELECT DISTINCT user_id, agent_id, week_start FROM events),
user_week_age AS (
  SELECT aw.user_id, aw.agent_id, fu.cohort_week_start,
         ((aw.week_start - DATE_TRUNC('week',fu.first_use_at)::date)/7)::int AS user_lifetime_week
  FROM activity_weeks aw JOIN first_use fu USING (user_id, agent_id)
  WHERE aw.week_start >= DATE_TRUNC('week',fu.first_use_at)::date
),
active_counts AS (
  SELECT agent_id, cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users
  FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2,3
),
cohort_sizes AS (
  SELECT agent_id, cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_use GROUP BY 1,2
),
cohort_caps AS (
  SELECT cs.agent_id, cs.cohort_week_start, cs.cohort_users,
         LEAST((SELECT max_weeks FROM params),
               GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
  FROM cohort_sizes cs
),
grid AS (
  SELECT cc.agent_id, cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
  FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
),
agent_names AS (SELECT DISTINCT ON (g."id") g."id" AS agent_id, g."name" AS agent_name FROM platform."AgentGraph" g ORDER BY g."id", g."version" DESC),
agent_total_users AS (SELECT agent_id, SUM(cohort_users) AS agent_total_users FROM cohort_sizes GROUP BY 1)
SELECT
  g.agent_id,
  COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||']' AS agent_label,
  COALESCE(an.agent_name,'(unnamed)')||' ['||LEFT(g.agent_id::text,8)||'] (n='||COALESCE(atu.agent_total_users,0)||')' AS agent_label_n,
  g.cohort_week_start,
  TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
  TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
  g.user_lifetime_week, g.cohort_users,
  COALESCE(ac.active_users,0) AS active_users,
  COALESCE(ac.active_users,0)::float / NULLIF(g.cohort_users,0) AS retention_rate,
  CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0,
  COALESCE(atu.agent_total_users,0) AS agent_total_users
FROM grid g
LEFT JOIN active_counts ac ON ac.agent_id=g.agent_id AND ac.cohort_week_start=g.cohort_week_start AND ac.user_lifetime_week=g.user_lifetime_week
LEFT JOIN agent_names an ON an.agent_id=g.agent_id
LEFT JOIN agent_total_users atu ON atu.agent_id=g.agent_id
ORDER BY agent_label, g.cohort_week_start, g.user_lifetime_week;
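One usage sketch on top of this view (the agent UUID below is a placeholder): an average retention curve for a single agent, pooling cohorts. Because each cohort contributes exactly one row per lifetime week, `SUM(cohort_users)` within a week counts only cohorts old enough to have reached that week, which keeps the denominator honest for recent cohorts.

```sql
-- Pooled retention curve for one agent (placeholder UUID).
SELECT user_lifetime_week,
       SUM(active_users)::float / NULLIF(SUM(cohort_users), 0) AS avg_retention
FROM analytics.retention_agent
WHERE agent_id = '00000000-0000-0000-0000-000000000000'
GROUP BY 1
ORDER BY 1;
```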
@@ -1,81 +0,0 @@
-- =============================================================
-- View: analytics.retention_execution_daily
-- Looker source alias: ds111 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Daily cohort retention based on agent executions.
-- Cohort anchor = day of user's FIRST ever execution.
-- Only includes cohorts from the last 90 days, up to day 30.
-- Great for early engagement analysis (did users run another
-- agent the next day?).
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records
--
-- OUTPUT COLUMNS
-- Same pattern as retention_login_daily.
-- cohort_day_start = day of first execution (not first login)
--
-- EXAMPLE QUERIES
-- -- Day-3 execution retention
-- SELECT cohort_label, retention_rate_bounded AS d3_retention
-- FROM analytics.retention_execution_daily
-- WHERE user_lifetime_day = 3 ORDER BY cohort_day_start;
-- =============================================================

WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days') AS cohort_start),
events AS (
  SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
         DATE_TRUNC('day', e."createdAt")::date AS day_start
  FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
),
first_exec AS (
  SELECT user_id, MIN(created_at) AS first_exec_at,
         DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
  FROM events GROUP BY 1
  HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
user_day_age AS (
  SELECT ad.user_id, fe.cohort_day_start,
         (ad.day_start - DATE_TRUNC('day',fe.first_exec_at)::date)::int AS user_lifetime_day
  FROM activity_days ad JOIN first_exec fe USING (user_id)
  WHERE ad.day_start >= DATE_TRUNC('day',fe.first_exec_at)::date
),
bounded_counts AS (
  SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
  FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
),
last_active AS (
  SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
),
unbounded_counts AS (
  SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
  FROM last_active la
  CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
  GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
cohort_caps AS (
  SELECT cs.cohort_day_start, cs.cohort_users,
         LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
  FROM cohort_sizes cs
),
grid AS (
  SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
  FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
)
SELECT
  g.cohort_day_start,
  TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
  TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
  g.user_lifetime_day, g.cohort_users,
  COALESCE(b.active_users_bounded,0) AS active_users_bounded,
  COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
  CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
  CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
  CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
ORDER BY g.cohort_day_start, g.user_lifetime_day;
@@ -1,81 +0,0 @@
-- =============================================================
-- View: analytics.retention_execution_weekly
-- Looker source alias: ds92 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention based on agent executions.
-- Cohort anchor = week of user's FIRST ever agent execution
-- (not first login). Only includes cohorts from the last 180 days.
-- Useful when you care about product engagement, not just visits.
--
-- SOURCE TABLES
-- platform.AgentGraphExecution — Execution records
--
-- OUTPUT COLUMNS
-- Same pattern as retention_login_weekly.
-- cohort_week_start = week of first execution (not first login)
--
-- EXAMPLE QUERIES
-- -- Week-2 execution retention
-- SELECT cohort_label, retention_rate_bounded
-- FROM analytics.retention_execution_weekly
-- WHERE user_lifetime_week = 2 ORDER BY cohort_week_start;
-- =============================================================

WITH params AS (SELECT 12::int AS max_weeks, (CURRENT_DATE - INTERVAL '180 days') AS cohort_start),
events AS (
  SELECT e."userId"::text AS user_id, e."createdAt"::timestamptz AS created_at,
         DATE_TRUNC('week', e."createdAt")::date AS week_start
  FROM platform."AgentGraphExecution" e WHERE e."userId" IS NOT NULL
),
first_exec AS (
  SELECT user_id, MIN(created_at) AS first_exec_at,
         DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
  FROM events GROUP BY 1
  HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
  SELECT aw.user_id, fe.cohort_week_start,
         ((aw.week_start - DATE_TRUNC('week',fe.first_exec_at)::date)/7)::int AS user_lifetime_week
  FROM activity_weeks aw JOIN first_exec fe USING (user_id)
  WHERE aw.week_start >= DATE_TRUNC('week',fe.first_exec_at)::date
),
bounded_counts AS (
  SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
  FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
  SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
  SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
  FROM last_active la
  CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
  GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_exec GROUP BY 1),
cohort_caps AS (
  SELECT cs.cohort_week_start, cs.cohort_users,
         LEAST((SELECT max_weeks FROM params),
               GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
  FROM cohort_sizes cs
),
grid AS (
  SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
  FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
)
SELECT
  g.cohort_week_start,
  TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
  TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
  g.user_lifetime_week, g.cohort_users,
  COALESCE(b.active_users_bounded,0) AS active_users_bounded,
  COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
  CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
  CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
  CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
ORDER BY g.cohort_week_start, g.user_lifetime_week;
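The two rates answer different questions: bounded retention is "active in exactly week k", unbounded is "ever active on or after week k", so the unbounded rate is always greater than or equal to the bounded one. A quick side-by-side sketch:

```sql
-- Bounded vs. unbounded retention for early weeks;
-- retention_rate_unbounded >= retention_rate_bounded by construction.
SELECT cohort_label, user_lifetime_week,
       retention_rate_bounded, retention_rate_unbounded
FROM analytics.retention_execution_weekly
WHERE user_lifetime_week BETWEEN 1 AND 4
ORDER BY cohort_week_start, user_lifetime_week;
```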
@@ -1,94 +0,0 @@
-- =============================================================
-- View: analytics.retention_login_daily
-- Looker source alias: ds112 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Daily cohort retention based on login sessions.
-- Same logic as retention_login_weekly but at day granularity,
-- showing up to day 30 for cohorts from the last 90 days.
-- Useful for analysing early activation (days 1-7) in detail.
--
-- SOURCE TABLES
-- auth.sessions — Login session records
--
-- OUTPUT COLUMNS (same pattern as retention_login_weekly)
-- cohort_day_start DATE First day the cohort logged in
-- cohort_label TEXT Date string (e.g. '2025-03-01')
-- cohort_label_n TEXT Date + cohort size (e.g. '2025-03-01 (n=12)')
-- user_lifetime_day INT Days since first login (0 = signup day)
-- cohort_users BIGINT Total users in cohort
-- active_users_bounded BIGINT Users active on exactly day k
-- retained_users_unbounded BIGINT Users active any time on/after day k
-- retention_rate_bounded FLOAT bounded / cohort_users
-- retention_rate_unbounded FLOAT unbounded / cohort_users
-- cohort_users_d0 BIGINT cohort_users only at day 0, else 0 (safe to SUM)
--
-- EXAMPLE QUERIES
-- -- Day-1 retention rate (came back next day)
-- SELECT cohort_label, retention_rate_bounded AS d1_retention
-- FROM analytics.retention_login_daily
-- WHERE user_lifetime_day = 1 ORDER BY cohort_day_start;
--
-- -- Average retention curve across all cohorts
-- SELECT user_lifetime_day,
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_d0), 0) AS avg_retention
-- FROM analytics.retention_login_daily
-- GROUP BY 1 ORDER BY 1;
-- =============================================================

WITH params AS (SELECT 30::int AS max_days, (CURRENT_DATE - INTERVAL '90 days')::date AS cohort_start),
events AS (
  SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
         DATE_TRUNC('day', s.created_at)::date AS day_start
  FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login AS (
  SELECT user_id, MIN(created_at) AS first_login_time,
         DATE_TRUNC('day', MIN(created_at))::date AS cohort_day_start
  FROM events GROUP BY 1
  HAVING MIN(created_at) >= (SELECT cohort_start FROM params)
),
activity_days AS (SELECT DISTINCT user_id, day_start FROM events),
user_day_age AS (
  SELECT ad.user_id, fl.cohort_day_start,
         (ad.day_start - DATE_TRUNC('day', fl.first_login_time)::date)::int AS user_lifetime_day
  FROM activity_days ad JOIN first_login fl USING (user_id)
  WHERE ad.day_start >= DATE_TRUNC('day', fl.first_login_time)::date
),
bounded_counts AS (
  SELECT cohort_day_start, user_lifetime_day, COUNT(DISTINCT user_id) AS active_users_bounded
  FROM user_day_age WHERE user_lifetime_day >= 0 GROUP BY 1,2
),
last_active AS (
  SELECT cohort_day_start, user_id, MAX(user_lifetime_day) AS last_active_day FROM user_day_age GROUP BY 1,2
),
unbounded_counts AS (
  SELECT la.cohort_day_start, gs AS user_lifetime_day, COUNT(*) AS retained_users_unbounded
  FROM last_active la
  CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_day,(SELECT max_days FROM params))) gs
  GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_day_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
cohort_caps AS (
  SELECT cs.cohort_day_start, cs.cohort_users,
         LEAST((SELECT max_days FROM params), GREATEST(0,(CURRENT_DATE-cs.cohort_day_start)::int)) AS cap_days
  FROM cohort_sizes cs
),
grid AS (
  SELECT cc.cohort_day_start, gs AS user_lifetime_day, cc.cohort_users
  FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_days) gs
)
SELECT
  g.cohort_day_start,
  TO_CHAR(g.cohort_day_start,'YYYY-MM-DD') AS cohort_label,
  TO_CHAR(g.cohort_day_start,'YYYY-MM-DD')||' (n='||g.cohort_users||')' AS cohort_label_n,
  g.user_lifetime_day, g.cohort_users,
  COALESCE(b.active_users_bounded,0) AS active_users_bounded,
  COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
  CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
  CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
  CASE WHEN g.user_lifetime_day=0 THEN g.cohort_users ELSE 0 END AS cohort_users_d0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_day_start=g.cohort_day_start AND b.user_lifetime_day=g.user_lifetime_day
LEFT JOIN unbounded_counts u ON u.cohort_day_start=g.cohort_day_start AND u.user_lifetime_day=g.user_lifetime_day
ORDER BY g.cohort_day_start, g.user_lifetime_day;
@@ -1,96 +0,0 @@
-- =============================================================
-- View: analytics.retention_login_onboarded_weekly
-- Looker source alias: ds101 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention from login sessions, restricted to
-- users who "onboarded" — defined as running at least one
-- agent within 365 days of their first login.
-- Filters out users who signed up but never activated,
-- giving a cleaner view of engaged-user retention.
--
-- SOURCE TABLES
-- auth.sessions — Login session records
-- platform.AgentGraphExecution — Used to identify onboarders
--
-- OUTPUT COLUMNS
-- Same as retention_login_weekly (cohort_week_start, user_lifetime_week,
-- retention_rate_bounded, retention_rate_unbounded, etc.)
-- Only difference: cohort is filtered to onboarded users only.
--
-- EXAMPLE QUERIES
-- -- Compare week-4 retention: all users vs onboarded only
-- SELECT 'all_users' AS segment, AVG(retention_rate_bounded) AS w4_retention
-- FROM analytics.retention_login_weekly WHERE user_lifetime_week = 4
-- UNION ALL
-- SELECT 'onboarded', AVG(retention_rate_bounded)
-- FROM analytics.retention_login_onboarded_weekly WHERE user_lifetime_week = 4;
-- =============================================================

WITH params AS (SELECT 12::int AS max_weeks, 365::int AS onboarding_window_days),
events AS (
  SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
         DATE_TRUNC('week', s.created_at)::date AS week_start
  FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login_all AS (
  SELECT user_id, MIN(created_at) AS first_login_time,
         DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
  FROM events GROUP BY 1
),
onboarders AS (
  SELECT fl.user_id FROM first_login_all fl
  WHERE EXISTS (
    SELECT 1 FROM platform."AgentGraphExecution" e
    WHERE e."userId"::text = fl.user_id
      AND e."createdAt" >= fl.first_login_time
      AND e."createdAt" < fl.first_login_time
        + make_interval(days => (SELECT onboarding_window_days FROM params))
  )
),
first_login AS (SELECT * FROM first_login_all WHERE user_id IN (SELECT user_id FROM onboarders)),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
  SELECT aw.user_id, fl.cohort_week_start,
         ((aw.week_start - DATE_TRUNC('week',fl.first_login_time)::date)/7)::int AS user_lifetime_week
  FROM activity_weeks aw JOIN first_login fl USING (user_id)
  WHERE aw.week_start >= DATE_TRUNC('week',fl.first_login_time)::date
),
bounded_counts AS (
  SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
  FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
  SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
  SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
  FROM last_active la
  CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
  GROUP BY 1,2
),
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
cohort_caps AS (
  SELECT cs.cohort_week_start, cs.cohort_users,
         LEAST((SELECT max_weeks FROM params),
               GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date-cs.cohort_week_start)/7)::int)) AS cap_weeks
  FROM cohort_sizes cs
),
grid AS (
  SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
  FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
)
SELECT
  g.cohort_week_start,
  TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
  TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
  g.user_lifetime_week, g.cohort_users,
  COALESCE(b.active_users_bounded,0) AS active_users_bounded,
  COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
  CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
  CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
  CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
FROM grid g
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
ORDER BY g.cohort_week_start, g.user_lifetime_week;
@@ -1,103 +0,0 @@
-- =============================================================
-- View: analytics.retention_login_weekly
-- Looker source alias: ds83 | Charts: 2
-- =============================================================
-- DESCRIPTION
-- Weekly cohort retention based on login sessions.
-- Users are grouped by the ISO week of their first ever login.
-- For each cohort × lifetime-week combination, outputs both:
-- - bounded rate: % active in exactly that week
-- - unbounded rate: % who were ever active on or after that week
-- Weeks are capped to the cohort's actual age (no future data points).
--
-- SOURCE TABLES
-- auth.sessions — Login session records
--
-- HOW TO READ THE OUTPUT
-- cohort_week_start The Monday of the week users first logged in
-- user_lifetime_week 0 = signup week, 1 = one week later, etc.
-- retention_rate_bounded = active_users_bounded / cohort_users
-- retention_rate_unbounded = retained_users_unbounded / cohort_users
--
-- OUTPUT COLUMNS
-- cohort_week_start DATE First day of the cohort's signup week
-- cohort_label TEXT ISO week label (e.g. '2025-W01')
-- cohort_label_n TEXT ISO week label with cohort size (e.g. '2025-W01 (n=42)')
-- user_lifetime_week INT Weeks since first login (0 = signup week)
-- cohort_users BIGINT Total users in this cohort (denominator)
-- active_users_bounded BIGINT Users active in exactly week k
-- retained_users_unbounded BIGINT Users active any time on/after week k
-- retention_rate_bounded FLOAT bounded active / cohort_users
-- retention_rate_unbounded FLOAT unbounded retained / cohort_users
-- cohort_users_w0 BIGINT cohort_users only at week 0, else 0 (safe to SUM in pivot tables)
--
-- EXAMPLE QUERIES
-- -- Week-1 retention rate per cohort
-- SELECT cohort_label, retention_rate_bounded AS w1_retention
-- FROM analytics.retention_login_weekly
-- WHERE user_lifetime_week = 1
-- ORDER BY cohort_week_start;
--
-- -- Overall average retention curve (all cohorts combined)
-- SELECT user_lifetime_week,
-- SUM(active_users_bounded)::float / NULLIF(SUM(cohort_users_w0), 0) AS avg_retention
-- FROM analytics.retention_login_weekly
-- GROUP BY 1 ORDER BY 1;
-- =============================================================

WITH params AS (SELECT 12::int AS max_weeks),
events AS (
  SELECT s.user_id::text AS user_id, s.created_at::timestamptz AS created_at,
         DATE_TRUNC('week', s.created_at)::date AS week_start
  FROM auth.sessions s WHERE s.user_id IS NOT NULL
),
first_login AS (
  SELECT user_id, MIN(created_at) AS first_login_time,
         DATE_TRUNC('week', MIN(created_at))::date AS cohort_week_start
  FROM events GROUP BY 1
),
activity_weeks AS (SELECT DISTINCT user_id, week_start FROM events),
user_week_age AS (
  SELECT aw.user_id, fl.cohort_week_start,
         ((aw.week_start - DATE_TRUNC('week', fl.first_login_time)::date) / 7)::int AS user_lifetime_week
  FROM activity_weeks aw JOIN first_login fl USING (user_id)
  WHERE aw.week_start >= DATE_TRUNC('week', fl.first_login_time)::date
),
bounded_counts AS (
  SELECT cohort_week_start, user_lifetime_week, COUNT(DISTINCT user_id) AS active_users_bounded
  FROM user_week_age WHERE user_lifetime_week >= 0 GROUP BY 1,2
),
last_active AS (
  SELECT cohort_week_start, user_id, MAX(user_lifetime_week) AS last_active_week FROM user_week_age GROUP BY 1,2
),
unbounded_counts AS (
  SELECT la.cohort_week_start, gs AS user_lifetime_week, COUNT(*) AS retained_users_unbounded
|
|
||||||
FROM last_active la
|
|
||||||
CROSS JOIN LATERAL generate_series(0, LEAST(la.last_active_week,(SELECT max_weeks FROM params))) gs
|
|
||||||
GROUP BY 1,2
|
|
||||||
),
|
|
||||||
cohort_sizes AS (SELECT cohort_week_start, COUNT(DISTINCT user_id) AS cohort_users FROM first_login GROUP BY 1),
|
|
||||||
cohort_caps AS (
|
|
||||||
SELECT cs.cohort_week_start, cs.cohort_users,
|
|
||||||
LEAST((SELECT max_weeks FROM params),
|
|
||||||
GREATEST(0,((DATE_TRUNC('week',CURRENT_DATE)::date - cs.cohort_week_start)/7)::int)) AS cap_weeks
|
|
||||||
FROM cohort_sizes cs
|
|
||||||
),
|
|
||||||
grid AS (
|
|
||||||
SELECT cc.cohort_week_start, gs AS user_lifetime_week, cc.cohort_users
|
|
||||||
FROM cohort_caps cc CROSS JOIN LATERAL generate_series(0, cc.cap_weeks) gs
|
|
||||||
)
|
|
||||||
SELECT
|
|
||||||
g.cohort_week_start,
|
|
||||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW') AS cohort_label,
|
|
||||||
TO_CHAR(g.cohort_week_start,'IYYY-"W"IW')||' (n='||g.cohort_users||')' AS cohort_label_n,
|
|
||||||
g.user_lifetime_week, g.cohort_users,
|
|
||||||
COALESCE(b.active_users_bounded,0) AS active_users_bounded,
|
|
||||||
COALESCE(u.retained_users_unbounded,0) AS retained_users_unbounded,
|
|
||||||
CASE WHEN g.cohort_users>0 THEN COALESCE(b.active_users_bounded,0)::float/g.cohort_users END AS retention_rate_bounded,
|
|
||||||
CASE WHEN g.cohort_users>0 THEN COALESCE(u.retained_users_unbounded,0)::float/g.cohort_users END AS retention_rate_unbounded,
|
|
||||||
CASE WHEN g.user_lifetime_week=0 THEN g.cohort_users ELSE 0 END AS cohort_users_w0
|
|
||||||
FROM grid g
|
|
||||||
LEFT JOIN bounded_counts b ON b.cohort_week_start=g.cohort_week_start AND b.user_lifetime_week=g.user_lifetime_week
|
|
||||||
LEFT JOIN unbounded_counts u ON u.cohort_week_start=g.cohort_week_start AND u.user_lifetime_week=g.user_lifetime_week
|
|
||||||
ORDER BY g.cohort_week_start, g.user_lifetime_week
|
|
||||||
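The bounded vs unbounded distinction the view computes can be modeled outside SQL. Below is a minimal Python sketch (the function and sample data are invented for illustration, not part of the repo): each user maps to the set of lifetime-week indices in which they logged in, bounded retention counts users active in exactly week k, and unbounded retention counts users whose last activity falls on or after week k.

```python
from collections import defaultdict

def retention(active_weeks_by_user: dict[str, set[int]], max_weeks: int = 12):
    """active_weeks_by_user: user -> set of lifetime-week indices with a login.
    Week 0 (the signup week) is present for every user by construction."""
    cohort = len(active_weeks_by_user)
    bounded = defaultdict(int)    # active in exactly week k
    unbounded = defaultdict(int)  # ever active on/after week k
    for weeks in active_weeks_by_user.values():
        for w in weeks:
            if w <= max_weeks:
                bounded[w] += 1
        # a user active in week `last` counts toward unbounded for 0..last
        for w in range(0, min(max(weeks), max_weeks) + 1):
            unbounded[w] += 1
    return {k: (bounded[k] / cohort, unbounded[k] / cohort)
            for k in range(0, max_weeks + 1) if bounded[k] or unbounded[k]}
```

Note how a user active in weeks {0, 1, 3} contributes 0% bounded but 50% unbounded retention at week 2 in a two-user cohort, matching the view's `last_active`/`generate_series` construction.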
@@ -1,71 +0,0 @@
-- =============================================================
-- View: analytics.user_block_spending
-- Looker source alias: ds6 | Charts: 5
-- =============================================================
-- DESCRIPTION
-- One row per credit transaction (last 90 days).
-- Shows how users spend credits broken down by block type,
-- LLM provider and model. Joins node execution stats for
-- token-level detail.
--
-- SOURCE TABLES
--   platform.CreditTransaction — Credit debit/credit records
--   platform.AgentNodeExecution — Node execution stats (for token counts)
--
-- OUTPUT COLUMNS
--   transactionKey          TEXT         Unique transaction identifier
--   userId                  TEXT         User who was charged
--   amount                  DECIMAL      Credit amount (positive = credit, negative = debit)
--   negativeAmount          DECIMAL      amount * -1 (convenience for spend charts)
--   transactionType         TEXT         Transaction type (e.g. 'USAGE', 'REFUND', 'TOP_UP')
--   transactionTime         TIMESTAMPTZ  When the transaction was recorded
--   blockId                 TEXT         Block UUID that triggered the spend
--   blockName               TEXT         Human-readable block name
--   llm_provider            TEXT         LLM provider (e.g. 'openai', 'anthropic')
--   llm_model               TEXT         Model name (e.g. 'gpt-4o', 'claude-3-5-sonnet')
--   node_exec_id            TEXT         Linked node execution UUID
--   llm_call_count          INT          LLM API calls made in that execution
--   llm_retry_count         INT          LLM retries in that execution
--   llm_input_token_count   INT          Input tokens consumed
--   llm_output_token_count  INT          Output tokens produced
--
-- WINDOW
--   Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Total spend per user (last 90 days)
-- SELECT "userId", SUM("negativeAmount") AS total_spent
-- FROM analytics.user_block_spending
-- WHERE "transactionType" = 'USAGE'
-- GROUP BY 1 ORDER BY total_spent DESC;
--
-- -- Spend by LLM provider + model
-- SELECT "llm_provider", "llm_model",
--        SUM("negativeAmount") AS total_cost,
--        SUM("llm_input_token_count") AS input_tokens,
--        SUM("llm_output_token_count") AS output_tokens
-- FROM analytics.user_block_spending
-- WHERE "llm_provider" IS NOT NULL
-- GROUP BY 1, 2 ORDER BY total_cost DESC;
-- =============================================================

-- NOTE: camelCase aliases must be double-quoted; unquoted aliases fold to
-- lowercase in Postgres and would break the documented column names above.
SELECT
    c."transactionKey" AS "transactionKey",
    c."userId" AS "userId",
    c."amount" AS amount,
    c."amount" * -1 AS "negativeAmount",
    c."type" AS "transactionType",
    c."createdAt" AS "transactionTime",
    c.metadata->>'block_id' AS "blockId",
    c.metadata->>'block' AS "blockName",
    c.metadata->'input'->'credentials'->>'provider' AS llm_provider,
    c.metadata->'input'->>'model' AS llm_model,
    c.metadata->>'node_exec_id' AS node_exec_id,
    (ne."stats"->>'llm_call_count')::int AS llm_call_count,
    (ne."stats"->>'llm_retry_count')::int AS llm_retry_count,
    (ne."stats"->>'input_token_count')::int AS llm_input_token_count,
    (ne."stats"->>'output_token_count')::int AS llm_output_token_count
FROM platform."CreditTransaction" c
LEFT JOIN platform."AgentNodeExecution" ne
    ON (c.metadata->>'node_exec_id') = ne."id"::text
WHERE c."createdAt" > CURRENT_DATE - INTERVAL '90 days'
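The chained `->` / `->>` extraction from `metadata` above can be mimicked in Python. This is a hypothetical sketch (the helper name and the sample payload are invented; only the key paths come from the view): walk nested keys, return the leaf as text or `None`, the way `metadata->'input'->'credentials'->>'provider'` does.

```python
import json

def json_path_text(doc, *path: str):
    """Rough analogue of chained Postgres `->` access ending in `->>`:
    follow nested keys, return the leaf coerced to text, or None if
    any intermediate key is missing or not an object."""
    cur = doc
    for key in path:
        if not isinstance(cur, dict):
            return None
        cur = cur.get(key)
    return None if cur is None else str(cur)

# Invented sample payload shaped like the view's metadata column.
metadata = json.loads(
    '{"block": "ExampleLLMBlock",'
    ' "input": {"model": "gpt-4o", "credentials": {"provider": "openai"}}}'
)
provider = json_path_text(metadata, "input", "credentials", "provider")
model = json_path_text(metadata, "input", "model")
```

As in SQL, a missing intermediate key yields `None` rather than an error, which is why the view can safely LEFT JOIN on `metadata->>'node_exec_id'`.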
@@ -1,45 +0,0 @@
-- =============================================================
-- View: analytics.user_onboarding
-- Looker source alias: ds68 | Charts: 3
-- =============================================================
-- DESCRIPTION
-- One row per user onboarding record. Contains the user's
-- stated usage reason, selected integrations, completed
-- onboarding steps and optional first agent selection.
-- Full history (no date filter) since onboarding happens
-- once per user.
--
-- SOURCE TABLES
--   platform.UserOnboarding — Onboarding state per user
--
-- OUTPUT COLUMNS
--   id                             TEXT         Onboarding record UUID
--   createdAt                      TIMESTAMPTZ  When onboarding started
--   updatedAt                      TIMESTAMPTZ  Last update to onboarding state
--   usageReason                    TEXT         Why user signed up (e.g. 'work', 'personal')
--   integrations                   TEXT[]       Array of integration names the user selected
--   userId                         TEXT         User UUID
--   completedSteps                 TEXT[]       Array of onboarding step enums completed
--   selectedStoreListingVersionId  TEXT         First marketplace agent the user chose (if any)
--
-- EXAMPLE QUERIES
-- -- Usage reason breakdown
-- SELECT "usageReason", COUNT(*) FROM analytics.user_onboarding GROUP BY 1;
--
-- -- Completion rate per step
-- SELECT step, COUNT(*) AS users_completed
-- FROM analytics.user_onboarding
-- CROSS JOIN LATERAL UNNEST("completedSteps") AS step
-- GROUP BY 1 ORDER BY users_completed DESC;
-- =============================================================

SELECT
    id,
    "createdAt",
    "updatedAt",
    "usageReason",
    integrations,
    "userId",
    "completedSteps",
    "selectedStoreListingVersionId"
FROM platform."UserOnboarding"
@@ -1,100 +0,0 @@
-- =============================================================
-- View: analytics.user_onboarding_funnel
-- Looker source alias: ds74 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Pre-aggregated onboarding funnel showing how many users
-- completed each step and the drop-off percentage from the
-- previous step. One row per onboarding step (all 22 steps
-- always present, even with 0 completions — prevents sparse
-- gaps from making LAG compare the wrong predecessors).
--
-- SOURCE TABLES
--   platform.UserOnboarding — Onboarding records with completedSteps array
--
-- OUTPUT COLUMNS
--   step             TEXT     Onboarding step enum name (e.g. 'WELCOME', 'CONGRATS')
--   step_order       INT      Numeric position in the funnel (1=first, 22=last)
--   users_completed  BIGINT   Distinct users who completed this step
--   pct_from_prev    NUMERIC  % of users from the previous step who reached this one
--
-- STEP ORDER
--    1 WELCOME         9 MARKETPLACE_VISIT      17 SCHEDULE_AGENT
--    2 USAGE_REASON   10 MARKETPLACE_ADD_AGENT  18 RUN_AGENTS
--    3 INTEGRATIONS   11 MARKETPLACE_RUN_AGENT  19 RUN_3_DAYS
--    4 AGENT_CHOICE   12 BUILDER_OPEN           20 TRIGGER_WEBHOOK
--    5 AGENT_NEW_RUN  13 BUILDER_SAVE_AGENT     21 RUN_14_DAYS
--    6 AGENT_INPUT    14 BUILDER_RUN_AGENT      22 RUN_AGENTS_100
--    7 CONGRATS       15 VISIT_COPILOT
--    8 GET_RESULTS    16 RE_RUN_AGENT
--
-- WINDOW
--   Users who started onboarding in the last 90 days
--
-- EXAMPLE QUERIES
-- -- Full funnel
-- SELECT * FROM analytics.user_onboarding_funnel ORDER BY step_order;
--
-- -- Biggest drop-off point
-- SELECT step, pct_from_prev FROM analytics.user_onboarding_funnel
-- ORDER BY pct_from_prev ASC LIMIT 3;
-- =============================================================

WITH all_steps AS (
    -- Complete ordered grid of all 22 steps so zero-completion steps
    -- are always present, keeping LAG comparisons correct.
    SELECT step_name, step_order
    FROM (VALUES
        ('WELCOME', 1),
        ('USAGE_REASON', 2),
        ('INTEGRATIONS', 3),
        ('AGENT_CHOICE', 4),
        ('AGENT_NEW_RUN', 5),
        ('AGENT_INPUT', 6),
        ('CONGRATS', 7),
        ('GET_RESULTS', 8),
        ('MARKETPLACE_VISIT', 9),
        ('MARKETPLACE_ADD_AGENT', 10),
        ('MARKETPLACE_RUN_AGENT', 11),
        ('BUILDER_OPEN', 12),
        ('BUILDER_SAVE_AGENT', 13),
        ('BUILDER_RUN_AGENT', 14),
        ('VISIT_COPILOT', 15),
        ('RE_RUN_AGENT', 16),
        ('SCHEDULE_AGENT', 17),
        ('RUN_AGENTS', 18),
        ('RUN_3_DAYS', 19),
        ('TRIGGER_WEBHOOK', 20),
        ('RUN_14_DAYS', 21),
        ('RUN_AGENTS_100', 22)
    ) AS t(step_name, step_order)
),
raw AS (
    SELECT
        u."userId",
        step_txt::text AS step
    FROM platform."UserOnboarding" u
    CROSS JOIN LATERAL UNNEST(u."completedSteps") AS step_txt
    WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
),
step_counts AS (
    SELECT step, COUNT(DISTINCT "userId") AS users_completed
    FROM raw GROUP BY step
),
funnel AS (
    SELECT
        a.step_name AS step,
        a.step_order,
        COALESCE(sc.users_completed, 0) AS users_completed,
        ROUND(
            100.0 * COALESCE(sc.users_completed, 0)
            / NULLIF(
                LAG(COALESCE(sc.users_completed, 0)) OVER (ORDER BY a.step_order),
                0
            ),
            2
        ) AS pct_from_prev
    FROM all_steps a
    LEFT JOIN step_counts sc ON sc.step = a.step_name
)
SELECT * FROM funnel ORDER BY step_order
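The dense-grid plus LAG pattern above can be sketched in a few lines of Python (a hypothetical model with invented sample counts, not repo code): every step in the ordered list is emitted even with zero completions, and `pct_from_prev` compares each step only to its immediate predecessor, returning `None` when the predecessor had zero users (mirroring `NULLIF(..., 0)`).

```python
def funnel_dropoff(step_counts: dict[str, int], step_order: list[str]):
    """Returns (step, step_order, users_completed, pct_from_prev) rows.
    Steps absent from step_counts still appear with 0 users, so the
    LAG-style comparison never skips over a sparse gap."""
    rows = []
    prev = None
    for i, step in enumerate(step_order, start=1):
        users = step_counts.get(step, 0)  # zero-completion steps stay present
        # None for the first step and whenever the previous step had 0 users,
        # matching ROUND(100.0 * n / NULLIF(LAG(n) OVER (...), 0), 2)
        pct = round(100.0 * users / prev, 2) if prev else None
        rows.append((step, i, users, pct))
        prev = users
    return rows

rows = funnel_dropoff(
    {"WELCOME": 100, "USAGE_REASON": 80, "AGENT_CHOICE": 20},
    ["WELCOME", "USAGE_REASON", "INTEGRATIONS", "AGENT_CHOICE"],
)
```

With a sparse gap (INTEGRATIONS at 0), the step after it gets `None` rather than an inflated percentage, which is exactly the behavior the dense `all_steps` grid guarantees in SQL.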
@@ -1,41 +0,0 @@
-- =============================================================
-- View: analytics.user_onboarding_integration
-- Looker source alias: ds75 | Charts: 1
-- =============================================================
-- DESCRIPTION
-- Pre-aggregated count of users who selected each integration
-- during onboarding. One row per integration type, sorted
-- by popularity.
--
-- SOURCE TABLES
--   platform.UserOnboarding — integrations array column
--
-- OUTPUT COLUMNS
--   integration             TEXT    Integration name (e.g. 'github', 'slack', 'notion')
--   users_with_integration  BIGINT  Distinct users who selected this integration
--
-- WINDOW
--   Users who started onboarding in the last 90 days
--
-- EXAMPLE QUERIES
-- -- Full integration popularity ranking
-- SELECT * FROM analytics.user_onboarding_integration;
--
-- -- Top 5 integrations
-- SELECT * FROM analytics.user_onboarding_integration LIMIT 5;
-- =============================================================

WITH exploded AS (
    SELECT
        u."userId" AS user_id,
        UNNEST(u."integrations") AS integration
    FROM platform."UserOnboarding" u
    WHERE u."createdAt" >= CURRENT_DATE - INTERVAL '90 days'
)
SELECT
    integration,
    COUNT(DISTINCT user_id) AS users_with_integration
FROM exploded
WHERE integration IS NOT NULL AND integration <> ''
GROUP BY integration
ORDER BY users_with_integration DESC
@@ -1,145 +0,0 @@
-- =============================================================
-- View: analytics.users_activities
-- Looker source alias: ds56 | Charts: 5
-- =============================================================
-- DESCRIPTION
-- One row per user with lifetime activity summary.
-- Joins login sessions with agent graphs, executions and
-- node-level runs to give a full picture of how engaged
-- each user is. Includes a convenience flag for 7-day
-- activation (did the user return at least 7 days after
-- their first login?).
--
-- SOURCE TABLES
--   auth.sessions — Login/session records
--   platform.AgentGraph — Graphs (agents) built by the user
--   platform.AgentGraphExecution — Agent run history
--   platform.AgentNodeExecution — Individual block execution history
--
-- PERFORMANCE NOTE
--   Each CTE aggregates its own table independently by userId.
--   This avoids the fan-out that occurs when driving every join
--   from user_logins across the two largest tables
--   (AgentGraphExecution and AgentNodeExecution).
--
-- OUTPUT COLUMNS
--   user_id                    TEXT         Supabase user UUID
--   first_login_time           TIMESTAMPTZ  First ever session created_at
--   last_login_time            TIMESTAMPTZ  Most recent session created_at
--   last_visit_time            TIMESTAMPTZ  Max of last refresh or login
--   last_agent_save_time       TIMESTAMPTZ  Last time user saved an agent graph
--   agent_count                BIGINT       Number of distinct active graphs built (0 if none)
--   first_agent_run_time       TIMESTAMPTZ  First ever graph execution
--   last_agent_run_time        TIMESTAMPTZ  Most recent graph execution
--   unique_agent_runs          BIGINT       Distinct agent graphs ever run (0 if none)
--   agent_runs                 BIGINT       Total graph execution count (0 if none)
--   node_execution_count       BIGINT       Total node executions across all runs
--   node_execution_failed      BIGINT       Node executions with FAILED status
--   node_execution_completed   BIGINT       Node executions with COMPLETED status
--   node_execution_terminated  BIGINT       Node executions with TERMINATED status
--   node_execution_queued      BIGINT       Node executions with QUEUED status
--   node_execution_running     BIGINT       Node executions with RUNNING status
--   is_active_after_7d         INT          1=returned after day 7, 0=did not, NULL=too early to tell
--   node_execution_incomplete  BIGINT       Node executions with INCOMPLETE status
--   node_execution_review      BIGINT       Node executions with REVIEW status
--
-- EXAMPLE QUERIES
-- -- Users who ran at least one agent and returned after 7 days
-- SELECT COUNT(*) FROM analytics.users_activities
-- WHERE agent_runs > 0 AND is_active_after_7d = 1;
--
-- -- Top 10 most active users by agent runs
-- SELECT user_id, agent_runs, node_execution_count
-- FROM analytics.users_activities
-- ORDER BY agent_runs DESC LIMIT 10;
--
-- -- 7-day activation rate
-- SELECT
--   SUM(CASE WHEN is_active_after_7d = 1 THEN 1 ELSE 0 END)::float
--   / NULLIF(COUNT(CASE WHEN is_active_after_7d IS NOT NULL THEN 1 END), 0)
--   AS activation_rate
-- FROM analytics.users_activities;
-- =============================================================

WITH user_logins AS (
    SELECT
        user_id::text AS user_id,
        MIN(created_at) AS first_login_time,
        MAX(created_at) AS last_login_time,
        GREATEST(
            MAX(refreshed_at)::timestamptz,
            MAX(created_at)::timestamptz
        ) AS last_visit_time
    FROM auth.sessions
    GROUP BY user_id
),
user_agents AS (
    -- Aggregate AgentGraph directly by userId (no fan-out from user_logins)
    SELECT
        "userId"::text AS user_id,
        MAX("updatedAt") AS last_agent_save_time,
        COUNT(DISTINCT "id") AS agent_count
    FROM platform."AgentGraph"
    WHERE "isActive"
    GROUP BY "userId"
),
user_graph_runs AS (
    -- Aggregate AgentGraphExecution directly by userId
    SELECT
        "userId"::text AS user_id,
        MIN("createdAt") AS first_agent_run_time,
        MAX("createdAt") AS last_agent_run_time,
        COUNT(DISTINCT "agentGraphId") AS unique_agent_runs,
        COUNT("id") AS agent_runs
    FROM platform."AgentGraphExecution"
    GROUP BY "userId"
),
user_node_runs AS (
    -- Aggregate AgentNodeExecution directly; resolve userId via a
    -- single join to AgentGraphExecution instead of fanning out from
    -- user_logins through both large tables.
    SELECT
        g."userId"::text AS user_id,
        COUNT(*) AS node_execution_count,
        COUNT(*) FILTER (WHERE n."executionStatus" = 'FAILED') AS node_execution_failed,
        COUNT(*) FILTER (WHERE n."executionStatus" = 'COMPLETED') AS node_execution_completed,
        COUNT(*) FILTER (WHERE n."executionStatus" = 'TERMINATED') AS node_execution_terminated,
        COUNT(*) FILTER (WHERE n."executionStatus" = 'QUEUED') AS node_execution_queued,
        COUNT(*) FILTER (WHERE n."executionStatus" = 'RUNNING') AS node_execution_running,
        COUNT(*) FILTER (WHERE n."executionStatus" = 'INCOMPLETE') AS node_execution_incomplete,
        COUNT(*) FILTER (WHERE n."executionStatus" = 'REVIEW') AS node_execution_review
    FROM platform."AgentNodeExecution" n
    JOIN platform."AgentGraphExecution" g
        ON g."id" = n."agentGraphExecutionId"
    GROUP BY g."userId"
)
SELECT
    ul.user_id,
    ul.first_login_time,
    ul.last_login_time,
    ul.last_visit_time,
    ua.last_agent_save_time,
    COALESCE(ua.agent_count, 0) AS agent_count,
    gr.first_agent_run_time,
    gr.last_agent_run_time,
    COALESCE(gr.unique_agent_runs, 0) AS unique_agent_runs,
    COALESCE(gr.agent_runs, 0) AS agent_runs,
    COALESCE(nr.node_execution_count, 0) AS node_execution_count,
    COALESCE(nr.node_execution_failed, 0) AS node_execution_failed,
    COALESCE(nr.node_execution_completed, 0) AS node_execution_completed,
    COALESCE(nr.node_execution_terminated, 0) AS node_execution_terminated,
    COALESCE(nr.node_execution_queued, 0) AS node_execution_queued,
    COALESCE(nr.node_execution_running, 0) AS node_execution_running,
    CASE
        WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
             AND ul.last_visit_time >= ul.first_login_time + INTERVAL '7 days' THEN 1
        WHEN ul.first_login_time < NOW() - INTERVAL '7 days'
             AND ul.last_visit_time < ul.first_login_time + INTERVAL '7 days' THEN 0
        ELSE NULL
    END AS is_active_after_7d,
    COALESCE(nr.node_execution_incomplete, 0) AS node_execution_incomplete,
    COALESCE(nr.node_execution_review, 0) AS node_execution_review
FROM user_logins ul
LEFT JOIN user_agents ua ON ul.user_id = ua.user_id
LEFT JOIN user_graph_runs gr ON ul.user_id = gr.user_id
LEFT JOIN user_node_runs nr ON ul.user_id = nr.user_id
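The tri-state `is_active_after_7d` CASE expression can be sketched as a small Python function (a hypothetical illustration of the same logic, with an injectable `now` for testing):

```python
from datetime import datetime, timedelta, timezone

def is_active_after_7d(first_login: datetime, last_visit: datetime, now: datetime):
    """Tri-state activation flag as computed by the view:
    1    -> user returned at least 7 days after first login,
    0    -> account is older than 7 days but the user never came back,
    None -> account is younger than 7 days (too early to tell)."""
    if first_login >= now - timedelta(days=7):
        return None  # ELSE NULL branch: cohort too young to evaluate
    return 1 if last_visit >= first_login + timedelta(days=7) else 0
```

The `None` branch matters for the activation-rate example query above it: NULL rows drop out of the denominator via `COUNT(CASE WHEN ... IS NOT NULL ...)`, so young accounts never drag the rate down.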
@@ -37,10 +37,6 @@ JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
 ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
 UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
 
-## ===== SIGNUP / INVITE GATE ===== ##
-# Set to true to require an invite before users can sign up
-ENABLE_INVITE_GATE=false
-
 ## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
 # Platform URLs (set these for webhooks and OAuth to work)
 PLATFORM_BASE_URL=http://localhost:8000
@@ -194,8 +190,5 @@ ZEROBOUNCE_API_KEY=
 POSTHOG_API_KEY=
 POSTHOG_HOST=https://eu.i.posthog.com
 
-# Tally Form Integration (pre-populate business understanding on signup)
-TALLY_API_KEY=
-
 # Other Services
 AUTOMOD_API_KEY=
@@ -58,31 +58,10 @@ poetry run pytest path/to/test.py --snapshot-update
 - **Authentication**: JWT-based with Supabase integration
 - **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
 
-## Code Style
-
-- **Top-level imports only** — no local/inner imports (lazy imports only for heavy optional deps like `openpyxl`)
-- **No duck typing** — no `hasattr`/`getattr`/`isinstance` for type dispatch; use typed interfaces/unions/protocols
-- **Pydantic models** over dataclass/namedtuple/dict for structured data
-- **No linter suppressors** — no `# type: ignore`, `# noqa`, `# pyright: ignore`; fix the type/code
-- **List comprehensions** over manual loop-and-append
-- **Early return** — guard clauses first, avoid deep nesting
-- **Lazy `%s` logging** — `logger.info("Processing %s items", count)` not `logger.info(f"Processing {count} items")`
-- **Sanitize error paths** — `os.path.basename()` in error messages to avoid leaking directory structure
-- **TOCTOU awareness** — avoid check-then-act patterns for file access and credit charging
-- **`Security()` vs `Depends()`** — use `Security()` for auth deps to get proper OpenAPI security spec
-- **Redis pipelines** — `transaction=True` for atomicity on multi-step operations
-- **`max(0, value)` guards** — for computed values that should never be negative
-- **SSE protocol** — `data:` lines for frontend-parsed events (must match Zod schema), `: comment` lines for heartbeats/status
-- **File length** — keep files under ~300 lines; if a file grows beyond this, split by responsibility (e.g. extract helpers, models, or a sub-module into a new file). Never keep appending to a long file.
-- **Function length** — keep functions under ~40 lines; extract named helpers when a function grows longer. Long functions are a sign of mixed concerns, not complexity.
-
 ## Testing Approach
 
 - Uses pytest with snapshot testing for API responses
 - Test files are colocated with source files (`*_test.py`)
-- Mock at boundaries — mock where the symbol is **used**, not where it's **defined**
-- After refactoring, update mock targets to match new module paths
-- Use `AsyncMock` for async functions (`from unittest.mock import AsyncMock`)
 
 ## Database Schema
@@ -95,7 +95,7 @@ ENV DEBIAN_FRONTEND=noninteractive
 
 # Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
 # bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
-# for the bash_exec MCP tool (fallback when E2B is not configured).
+# for the bash_exec MCP tool.
 # Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
 RUN apt-get update && apt-get install -y --no-install-recommends \
     python3.13 \
@@ -111,29 +111,13 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
 # Copy poetry (build-time only, for `poetry install --only-root` to create entry points)
 COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
 COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
-# Copy Node.js installation for Prisma and agent-browser.
-# npm/npx are symlinks in the builder (-> ../lib/node_modules/npm/bin/*-cli.js);
-# COPY resolves them to regular files, breaking require() paths. Recreate as
-# proper symlinks so npm/npx can find their modules.
+# Copy Node.js installation for Prisma
 COPY --from=builder /usr/bin/node /usr/bin/node
 COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
-RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
-    && ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
+COPY --from=builder /usr/bin/npm /usr/bin/npm
+COPY --from=builder /usr/bin/npx /usr/bin/npx
 COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
-
-# Install agent-browser (Copilot browser tool) + Chromium runtime dependencies.
-# These are the runtime libraries Chromium/Playwright needs on Debian 13 (trixie).
-RUN apt-get update && apt-get install -y --no-install-recommends \
-    libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
-    libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
-    libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
-    libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
-    fonts-liberation libfontconfig1 \
-    && rm -rf /var/lib/apt/lists/* \
-    && npm install -g agent-browser \
-    && agent-browser install \
-    && rm -rf /tmp/* /root/.npm
 
 WORKDIR /app/autogpt_platform/backend
 
 # Copy only the .venv from builder (not the entire /app directory)
@@ -88,23 +88,20 @@ async def require_auth(
     )
 
 
-def require_permission(*permissions: APIKeyPermission):
+def require_permission(permission: APIKeyPermission):
     """
-    Dependency function for checking required permissions.
-    All listed permissions must be present.
+    Dependency function for checking specific permissions
     (works with API keys and OAuth tokens)
     """
 
-    async def check_permissions(
+    async def check_permission(
         auth: APIAuthorizationInfo = Security(require_auth),
     ) -> APIAuthorizationInfo:
-        missing = [p for p in permissions if p not in auth.scopes]
-        if missing:
+        if permission not in auth.scopes:
             raise HTTPException(
                 status_code=status.HTTP_403_FORBIDDEN,
-                detail=f"Missing required permission(s): "
-                f"{', '.join(p.value for p in missing)}",
+                detail=f"Missing required permission: {permission.value}",
             )
         return auth
 
-    return check_permissions
+    return check_permission
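The varargs factory on the left side of this hunk can be exercised outside FastAPI with plain stand-ins. A minimal sketch — `Permission`, `AuthInfo`, and the use of `PermissionError` are illustrative substitutes, not the project's `APIKeyPermission` or `APIAuthorizationInfo`:

```python
from enum import Enum


class Permission(str, Enum):
    # Illustrative permission values, not the project's APIKeyPermission enum
    READ_STORE = "read_store"
    WRITE_GRAPH = "write_graph"


class AuthInfo:
    """Stand-in for APIAuthorizationInfo: just carries the granted scopes."""

    def __init__(self, scopes: set[Permission]) -> None:
        self.scopes = scopes


def require_permission(*permissions: Permission):
    """Varargs form from the left side of the diff: every listed permission
    must be present, and all missing ones are reported at once."""

    def check_permissions(auth: AuthInfo) -> AuthInfo:
        missing = [p for p in permissions if p not in auth.scopes]
        if missing:
            raise PermissionError(
                "Missing required permission(s): "
                + ", ".join(p.value for p in missing)
            )
        return auth

    return check_permissions
```

The right-hand side collapses this to a single `permission` argument, so an endpoint needing several permissions has to compose multiple dependencies instead of passing them in one call.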
@@ -1,7 +1,7 @@
 import logging
 import urllib.parse
 from collections import defaultdict
-from typing import Annotated, Any, Optional, Sequence
+from typing import Annotated, Any, Literal, Optional, Sequence
 
 from fastapi import APIRouter, Body, HTTPException, Security
 from prisma.enums import AgentExecutionStatus, APIKeyPermission
@@ -9,17 +9,15 @@ from pydantic import BaseModel, Field
 from typing_extensions import TypedDict
 
 import backend.api.features.store.cache as store_cache
-import backend.api.features.store.db as store_db
 import backend.api.features.store.model as store_model
 import backend.blocks
-from backend.api.external.middleware import require_auth, require_permission
+from backend.api.external.middleware import require_permission
 from backend.data import execution as execution_db
 from backend.data import graph as graph_db
 from backend.data import user as user_db
 from backend.data.auth.base import APIAuthorizationInfo
 from backend.data.block import BlockInput, CompletedBlockOutput
 from backend.executor.utils import add_graph_execution
-from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
 from backend.util.settings import Settings
 
 from .integrations import integrations_router
@@ -97,43 +95,6 @@ async def execute_graph_block(
     return output
 
 
-@v1_router.post(
-    path="/graphs",
-    tags=["graphs"],
-    status_code=201,
-    dependencies=[
-        Security(
-            require_permission(
-                APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY
-            )
-        )
-    ],
-)
-async def create_graph(
-    graph: graph_db.Graph,
-    auth: APIAuthorizationInfo = Security(
-        require_permission(APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY)
-    ),
-) -> graph_db.GraphModel:
-    """
-    Create a new agent graph.
-
-    The graph will be validated and assigned a new ID.
-    It is automatically added to the user's library.
-    """
-    from backend.api.features.library import db as library_db
-
-    graph_model = graph_db.make_graph_model(graph, auth.user_id)
-    graph_model.reassign_ids(user_id=auth.user_id, reassign_graph_id=True)
-    graph_model.validate_graph(for_run=False)
-
-    await graph_db.create_graph(graph_model, user_id=auth.user_id)
-    await library_db.create_library_agent(graph_model, auth.user_id)
-    activated_graph = await on_graph_activate(graph_model, user_id=auth.user_id)
-
-    return activated_graph
-
-
 @v1_router.post(
     path="/graphs/{graph_id}/execute/{graph_version}",
     tags=["graphs"],
@@ -231,13 +192,13 @@ async def get_graph_execution_results(
 @v1_router.get(
     path="/store/agents",
     tags=["store"],
-    dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
+    dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
     response_model=store_model.StoreAgentsResponse,
 )
 async def get_store_agents(
     featured: bool = False,
     creator: str | None = None,
-    sorted_by: store_db.StoreAgentsSortOptions | None = None,
+    sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
     search_query: str | None = None,
     category: str | None = None,
     page: int = 1,
@@ -279,7 +240,7 @@ async def get_store_agents(
 @v1_router.get(
     path="/store/agents/{username}/{agent_name}",
     tags=["store"],
-    dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
+    dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
     response_model=store_model.StoreAgentDetails,
 )
 async def get_store_agent(
@@ -307,13 +268,13 @@ async def get_store_agent(
 @v1_router.get(
     path="/store/creators",
     tags=["store"],
-    dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
+    dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
     response_model=store_model.CreatorsResponse,
 )
 async def get_store_creators(
     featured: bool = False,
     search_query: str | None = None,
-    sorted_by: store_db.StoreCreatorsSortOptions | None = None,
+    sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
     page: int = 1,
     page_size: int = 20,
 ) -> store_model.CreatorsResponse:
@@ -349,7 +310,7 @@ async def get_store_creators(
 @v1_router.get(
     path="/store/creators/{username}",
     tags=["store"],
-    dependencies=[Security(require_auth)], # data is public; auth required as anti-DDoS
+    dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
     response_model=store_model.CreatorDetails,
 )
 async def get_store_creator(
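The `Literal[...]` annotations substituted in these store endpoints enumerate the allowed sort values inline instead of referencing `store_db` types. A sketch of the constraint they express — `validate_sort` is a hypothetical helper; in the real routes FastAPI performs this validation itself for annotated query parameters:

```python
from typing import Literal, Optional, get_args

# Mirrors the inline literal type used for /store/agents on the right-hand side
AgentSortOption = Literal["rating", "runs", "name", "updated_at"]


def validate_sort(sorted_by: Optional[str]) -> Optional[str]:
    """Reject any value outside the Literal's members, the way FastAPI
    would answer a mistyped query parameter with a 422."""
    if sorted_by is not None and sorted_by not in get_args(AgentSortOption):
        raise ValueError(f"invalid sort option: {sorted_by!r}")
    return sorted_by
```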
@@ -1,17 +1,8 @@
-from __future__ import annotations
+from pydantic import BaseModel
 
-from datetime import datetime
-from typing import TYPE_CHECKING, Any, Literal, Optional
-
-import prisma.enums
-from pydantic import BaseModel, EmailStr
-
 from backend.data.model import UserTransaction
 from backend.util.models import Pagination
 
-if TYPE_CHECKING:
-    from backend.data.invited_user import BulkInvitedUsersResult, InvitedUserRecord
-
 
 class UserHistoryResponse(BaseModel):
     """Response model for listings with version history"""
@@ -23,70 +14,3 @@ class UserHistoryResponse(BaseModel):
 class AddUserCreditsResponse(BaseModel):
     new_balance: int
     transaction_key: str
-
-
-class CreateInvitedUserRequest(BaseModel):
-    email: EmailStr
-    name: Optional[str] = None
-
-
-class InvitedUserResponse(BaseModel):
-    id: str
-    email: str
-    status: prisma.enums.InvitedUserStatus
-    auth_user_id: Optional[str] = None
-    name: Optional[str] = None
-    tally_understanding: Optional[dict[str, Any]] = None
-    tally_status: prisma.enums.TallyComputationStatus
-    tally_computed_at: Optional[datetime] = None
-    tally_error: Optional[str] = None
-    created_at: datetime
-    updated_at: datetime
-
-    @classmethod
-    def from_record(cls, record: InvitedUserRecord) -> InvitedUserResponse:
-        return cls.model_validate(record.model_dump())
-
-
-class InvitedUsersResponse(BaseModel):
-    invited_users: list[InvitedUserResponse]
-    pagination: Pagination
-
-
-class BulkInvitedUserRowResponse(BaseModel):
-    row_number: int
-    email: Optional[str] = None
-    name: Optional[str] = None
-    status: Literal["CREATED", "SKIPPED", "ERROR"]
-    message: str
-    invited_user: Optional[InvitedUserResponse] = None
-
-
-class BulkInvitedUsersResponse(BaseModel):
-    created_count: int
-    skipped_count: int
-    error_count: int
-    results: list[BulkInvitedUserRowResponse]
-
-    @classmethod
-    def from_result(cls, result: BulkInvitedUsersResult) -> BulkInvitedUsersResponse:
-        return cls(
-            created_count=result.created_count,
-            skipped_count=result.skipped_count,
-            error_count=result.error_count,
-            results=[
-                BulkInvitedUserRowResponse(
-                    row_number=row.row_number,
-                    email=row.email,
-                    name=row.name,
-                    status=row.status,
-                    message=row.message,
-                    invited_user=(
-                        InvitedUserResponse.from_record(row.invited_user)
-                        if row.invited_user is not None
-                        else None
-                    ),
-                )
-                for row in result.results
-            ],
-        )
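The deleted `InvitedUserResponse.from_record` converts one model into another by dumping it to a dict and re-validating. The same dump-and-rebuild idiom, sketched with stdlib dataclasses standing in for pydantic's `model_dump`/`model_validate` (both classes here are simplified stand-ins, not the project's full models):

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class InvitedUserRecord:  # stand-in for the internal record type
    id: str
    email: str
    name: Optional[str] = None


@dataclass
class InvitedUserResponse:  # stand-in for the deleted response model
    id: str
    email: str
    name: Optional[str] = None

    @classmethod
    def from_record(cls, record: InvitedUserRecord) -> "InvitedUserResponse":
        # Equivalent of cls.model_validate(record.model_dump()):
        # serialize the source object, then rebuild it as the target type.
        return cls(**asdict(record))
```

The indirection lets the API response schema evolve independently of the internal record while keeping the conversion a one-liner.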
@@ -24,13 +24,14 @@ router = fastapi.APIRouter(
 @router.get(
     "/listings",
     summary="Get Admin Listings History",
+    response_model=store_model.StoreListingsWithVersionsResponse,
 )
 async def get_admin_listings_with_versions(
     status: typing.Optional[prisma.enums.SubmissionStatus] = None,
     search: typing.Optional[str] = None,
     page: int = 1,
     page_size: int = 20,
-) -> store_model.StoreListingsWithVersionsAdminViewResponse:
+):
     """
     Get store listings with their version history for admins.
 
@@ -44,26 +45,36 @@ async def get_admin_listings_with_versions(
         page_size: Number of items per page
 
     Returns:
-        Paginated listings with their versions
+        StoreListingsWithVersionsResponse with listings and their versions
     """
-    listings = await store_db.get_admin_listings_with_versions(
-        status=status,
-        search_query=search,
-        page=page,
-        page_size=page_size,
-    )
-    return listings
+    try:
+        listings = await store_db.get_admin_listings_with_versions(
+            status=status,
+            search_query=search,
+            page=page,
+            page_size=page_size,
+        )
+        return listings
+    except Exception as e:
+        logger.exception("Error getting admin listings with versions: %s", e)
+        return fastapi.responses.JSONResponse(
+            status_code=500,
+            content={
+                "detail": "An error occurred while retrieving listings with versions"
+            },
+        )
 
 
 @router.post(
     "/submissions/{store_listing_version_id}/review",
     summary="Review Store Submission",
+    response_model=store_model.StoreSubmission,
 )
 async def review_submission(
     store_listing_version_id: str,
     request: store_model.ReviewSubmissionRequest,
     user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
-) -> store_model.StoreSubmissionAdminView:
+):
     """
     Review a store listing submission.
 
@@ -73,24 +84,31 @@ async def review_submission(
         user_id: Authenticated admin user performing the review
 
     Returns:
-        StoreSubmissionAdminView with updated review information
+        StoreSubmission with updated review information
     """
-    already_approved = await store_db.check_submission_already_approved(
-        store_listing_version_id=store_listing_version_id,
-    )
-    submission = await store_db.review_store_submission(
-        store_listing_version_id=store_listing_version_id,
-        is_approved=request.is_approved,
-        external_comments=request.comments,
-        internal_comments=request.internal_comments or "",
-        reviewer_id=user_id,
-    )
-
-    state_changed = already_approved != request.is_approved
-    # Clear caches whenever approval state changes, since store visibility can change
-    if state_changed:
-        store_cache.clear_all_caches()
-    return submission
+    try:
+        already_approved = await store_db.check_submission_already_approved(
+            store_listing_version_id=store_listing_version_id,
+        )
+        submission = await store_db.review_store_submission(
+            store_listing_version_id=store_listing_version_id,
+            is_approved=request.is_approved,
+            external_comments=request.comments,
+            internal_comments=request.internal_comments or "",
+            reviewer_id=user_id,
+        )
+
+        state_changed = already_approved != request.is_approved
+        # Clear caches when the request is approved as it updates what is shown on the store
+        if state_changed:
+            store_cache.clear_all_caches()
+        return submission
+    except Exception as e:
+        logger.exception("Error reviewing submission: %s", e)
+        return fastapi.responses.JSONResponse(
+            status_code=500,
+            content={"detail": "An error occurred while reviewing the submission"},
+        )
 
 
 @router.get(
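Both rewritten admin handlers take the same shape: wrap the body in `try`, log with `logger.exception`, and return a generic 500 payload instead of leaking internals to the client. The control flow, reduced to plain functions — `shield` and the dict payloads are illustrative, not FastAPI's API:

```python
import logging

logger = logging.getLogger("admin_routes_sketch")


def shield(handler):
    """Decorator mirroring the try/except added to the admin routes."""

    def wrapper(*args, **kwargs):
        try:
            return handler(*args, **kwargs)
        except Exception as e:
            logger.exception("Error handling request: %s", e)
            # Generic detail string, matching the JSONResponse(status_code=500, ...) pattern
            return {"status_code": 500, "detail": "An internal error occurred"}

    return wrapper


@shield
def review_submission(is_approved):
    if not isinstance(is_approved, bool):
        raise ValueError("malformed request body")
    return {"status_code": 200, "approved": is_approved}
```

The trade-off in the real routes is that swallowing `Exception` also hides programming errors behind a 500; the `logger.exception` call is what keeps the traceback recoverable.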
@@ -1,137 +0,0 @@
-import logging
-import math
-
-from autogpt_libs.auth import get_user_id, requires_admin_user
-from fastapi import APIRouter, File, Query, Security, UploadFile
-
-from backend.data.invited_user import (
-    bulk_create_invited_users_from_file,
-    create_invited_user,
-    list_invited_users,
-    retry_invited_user_tally,
-    revoke_invited_user,
-)
-from backend.data.tally import mask_email
-from backend.util.models import Pagination
-
-from .model import (
-    BulkInvitedUsersResponse,
-    CreateInvitedUserRequest,
-    InvitedUserResponse,
-    InvitedUsersResponse,
-)
-
-logger = logging.getLogger(__name__)
-
-
-router = APIRouter(
-    prefix="/admin",
-    tags=["users", "admin"],
-    dependencies=[Security(requires_admin_user)],
-)
-
-
-@router.get(
-    "/invited-users",
-    response_model=InvitedUsersResponse,
-    summary="List Invited Users",
-)
-async def get_invited_users(
-    admin_user_id: str = Security(get_user_id),
-    page: int = Query(1, ge=1),
-    page_size: int = Query(50, ge=1, le=200),
-) -> InvitedUsersResponse:
-    logger.info("Admin user %s requested invited users", admin_user_id)
-    invited_users, total = await list_invited_users(page=page, page_size=page_size)
-    return InvitedUsersResponse(
-        invited_users=[InvitedUserResponse.from_record(iu) for iu in invited_users],
-        pagination=Pagination(
-            total_items=total,
-            total_pages=max(1, math.ceil(total / page_size)),
-            current_page=page,
-            page_size=page_size,
-        ),
-    )
-
-
-@router.post(
-    "/invited-users",
-    response_model=InvitedUserResponse,
-    summary="Create Invited User",
-)
-async def create_invited_user_route(
-    request: CreateInvitedUserRequest,
-    admin_user_id: str = Security(get_user_id),
-) -> InvitedUserResponse:
-    logger.info(
-        "Admin user %s creating invited user for %s",
-        admin_user_id,
-        mask_email(request.email),
-    )
-    invited_user = await create_invited_user(request.email, request.name)
-    logger.info(
-        "Admin user %s created invited user %s",
-        admin_user_id,
-        invited_user.id,
-    )
-    return InvitedUserResponse.from_record(invited_user)
-
-
-@router.post(
-    "/invited-users/bulk",
-    response_model=BulkInvitedUsersResponse,
-    summary="Bulk Create Invited Users",
-    operation_id="postV2BulkCreateInvitedUsers",
-)
-async def bulk_create_invited_users_route(
-    file: UploadFile = File(...),
-    admin_user_id: str = Security(get_user_id),
-) -> BulkInvitedUsersResponse:
-    logger.info(
-        "Admin user %s bulk invited users from %s",
-        admin_user_id,
-        file.filename or "<unnamed>",
-    )
-    content = await file.read()
-    result = await bulk_create_invited_users_from_file(file.filename, content)
-    return BulkInvitedUsersResponse.from_result(result)
-
-
-@router.post(
-    "/invited-users/{invited_user_id}/revoke",
-    response_model=InvitedUserResponse,
-    summary="Revoke Invited User",
-)
-async def revoke_invited_user_route(
-    invited_user_id: str,
-    admin_user_id: str = Security(get_user_id),
-) -> InvitedUserResponse:
-    logger.info(
-        "Admin user %s revoking invited user %s", admin_user_id, invited_user_id
-    )
-    invited_user = await revoke_invited_user(invited_user_id)
-    logger.info("Admin user %s revoked invited user %s", admin_user_id, invited_user_id)
-    return InvitedUserResponse.from_record(invited_user)
-
-
-@router.post(
-    "/invited-users/{invited_user_id}/retry-tally",
-    response_model=InvitedUserResponse,
-    summary="Retry Invited User Tally",
-)
-async def retry_invited_user_tally_route(
-    invited_user_id: str,
-    admin_user_id: str = Security(get_user_id),
-) -> InvitedUserResponse:
-    logger.info(
-        "Admin user %s retrying Tally seed for invited user %s",
-        admin_user_id,
-        invited_user_id,
-    )
-    invited_user = await retry_invited_user_tally(invited_user_id)
-    logger.info(
-        "Admin user %s retried Tally seed for invited user %s",
-        admin_user_id,
-        invited_user_id,
-    )
-    return InvitedUserResponse.from_record(invited_user)
@@ -1,168 +0,0 @@
-from datetime import datetime, timezone
-from unittest.mock import AsyncMock
-
-import fastapi
-import fastapi.testclient
-import prisma.enums
-import pytest
-import pytest_mock
-from autogpt_libs.auth.jwt_utils import get_jwt_payload
-
-from backend.data.invited_user import (
-    BulkInvitedUserRowResult,
-    BulkInvitedUsersResult,
-    InvitedUserRecord,
-)
-
-from .user_admin_routes import router as user_admin_router
-
-app = fastapi.FastAPI()
-app.include_router(user_admin_router)
-
-client = fastapi.testclient.TestClient(app)
-
-
-@pytest.fixture(autouse=True)
-def setup_app_admin_auth(mock_jwt_admin):
-    app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
-    yield
-    app.dependency_overrides.clear()
-
-
-def _sample_invited_user() -> InvitedUserRecord:
-    now = datetime.now(timezone.utc)
-    return InvitedUserRecord(
-        id="invite-1",
-        email="invited@example.com",
-        status=prisma.enums.InvitedUserStatus.INVITED,
-        auth_user_id=None,
-        name="Invited User",
-        tally_understanding=None,
-        tally_status=prisma.enums.TallyComputationStatus.PENDING,
-        tally_computed_at=None,
-        tally_error=None,
-        created_at=now,
-        updated_at=now,
-    )
-
-
-def _sample_bulk_invited_users_result() -> BulkInvitedUsersResult:
-    return BulkInvitedUsersResult(
-        created_count=1,
-        skipped_count=1,
-        error_count=0,
-        results=[
-            BulkInvitedUserRowResult(
-                row_number=1,
-                email="invited@example.com",
-                name=None,
-                status="CREATED",
-                message="Invite created",
-                invited_user=_sample_invited_user(),
-            ),
-            BulkInvitedUserRowResult(
-                row_number=2,
-                email="duplicate@example.com",
-                name=None,
-                status="SKIPPED",
-                message="An invited user with this email already exists",
-                invited_user=None,
-            ),
-        ],
-    )
-
-
-def test_get_invited_users(
-    mocker: pytest_mock.MockerFixture,
-) -> None:
-    mocker.patch(
-        "backend.api.features.admin.user_admin_routes.list_invited_users",
-        AsyncMock(return_value=([_sample_invited_user()], 1)),
-    )
-
-    response = client.get("/admin/invited-users")
-
-    assert response.status_code == 200
-    data = response.json()
-    assert len(data["invited_users"]) == 1
-    assert data["invited_users"][0]["email"] == "invited@example.com"
-    assert data["invited_users"][0]["status"] == "INVITED"
-    assert data["pagination"]["total_items"] == 1
-    assert data["pagination"]["current_page"] == 1
-    assert data["pagination"]["page_size"] == 50
-
-
-def test_create_invited_user(
-    mocker: pytest_mock.MockerFixture,
-) -> None:
-    mocker.patch(
-        "backend.api.features.admin.user_admin_routes.create_invited_user",
-        AsyncMock(return_value=_sample_invited_user()),
-    )
-
-    response = client.post(
-        "/admin/invited-users",
-        json={"email": "invited@example.com", "name": "Invited User"},
-    )
-
-    assert response.status_code == 200
-    data = response.json()
-    assert data["email"] == "invited@example.com"
-    assert data["name"] == "Invited User"
-
-
-def test_bulk_create_invited_users(
-    mocker: pytest_mock.MockerFixture,
-) -> None:
-    mocker.patch(
-        "backend.api.features.admin.user_admin_routes.bulk_create_invited_users_from_file",
-        AsyncMock(return_value=_sample_bulk_invited_users_result()),
-    )
-
-    response = client.post(
-        "/admin/invited-users/bulk",
-        files={
-            "file": ("invites.txt", b"invited@example.com\nduplicate@example.com\n")
-        },
-    )
-
-    assert response.status_code == 200
-    data = response.json()
-    assert data["created_count"] == 1
-    assert data["skipped_count"] == 1
-    assert data["results"][0]["status"] == "CREATED"
-    assert data["results"][1]["status"] == "SKIPPED"
-
-
-def test_revoke_invited_user(
-    mocker: pytest_mock.MockerFixture,
-) -> None:
-    revoked = _sample_invited_user().model_copy(
-        update={"status": prisma.enums.InvitedUserStatus.REVOKED}
-    )
-    mocker.patch(
-        "backend.api.features.admin.user_admin_routes.revoke_invited_user",
-        AsyncMock(return_value=revoked),
-    )
-
-    response = client.post("/admin/invited-users/invite-1/revoke")
-
-    assert response.status_code == 200
-    assert response.json()["status"] == "REVOKED"
-
-
-def test_retry_invited_user_tally(
-    mocker: pytest_mock.MockerFixture,
-) -> None:
-    retried = _sample_invited_user().model_copy(
-        update={"tally_status": prisma.enums.TallyComputationStatus.RUNNING}
-    )
-    mocker.patch(
-        "backend.api.features.admin.user_admin_routes.retry_invited_user_tally",
-        AsyncMock(return_value=retried),
-    )
-
-    response = client.post("/admin/invited-users/invite-1/retry-tally")
-
-    assert response.status_code == 200
-    assert response.json()["tally_status"] == "RUNNING"
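The deleted tests lean on one idiom worth keeping in mind: patching an async DB function with `AsyncMock` at the route module's import path, so the route awaits the mock instead of the database. Stripped of FastAPI and pytest-mock, the same mechanism looks like this (the coroutine names are stand-ins for this sketch):

```python
import asyncio
from unittest.mock import AsyncMock, patch


async def list_invited_users(page: int = 1, page_size: int = 50):
    # Stand-in for the real DB call; the tests never let this run.
    raise RuntimeError("would hit the database")


async def get_invited_users():
    # The name is looked up on the module at call time, which is what makes
    # mocker.patch("...user_admin_routes.list_invited_users", ...) effective.
    users, total = await list_invited_users()
    return {"invited_users": users, "total": total}


def run_patched():
    mock = AsyncMock(return_value=(["invited@example.com"], 1))
    with patch(f"{__name__}.list_invited_users", mock):
        return asyncio.run(get_invited_users())
```

Patching the *route module's* reference (not the DB module's) is the crucial detail: the handler resolves the name from its own globals, so that is where the mock must be installed.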
@@ -1,17 +1,15 @@
 import logging
 from dataclasses import dataclass
+from datetime import datetime, timedelta, timezone
 from difflib import SequenceMatcher
-from typing import Any, Sequence, get_args, get_origin
+from typing import Sequence
 
 import prisma
-from prisma.enums import ContentType
-from prisma.models import mv_suggested_blocks
 
 import backend.api.features.library.db as library_db
 import backend.api.features.library.model as library_model
 import backend.api.features.store.db as store_db
 import backend.api.features.store.model as store_model
-from backend.api.features.store.hybrid_search import unified_hybrid_search
 from backend.blocks import load_all_blocks
 from backend.blocks._base import (
     AnyBlockSchema,
@@ -21,6 +19,7 @@ from backend.blocks._base import (
     BlockType,
 )
 from backend.blocks.llm import LlmModel
+from backend.data.db import query_raw_with_schema
 from backend.integrations.providers import ProviderName
 from backend.util.cache import cached
 from backend.util.models import Pagination
@@ -43,16 +42,6 @@ MAX_LIBRARY_AGENT_RESULTS = 100
 MAX_MARKETPLACE_AGENT_RESULTS = 100
 MIN_SCORE_FOR_FILTERED_RESULTS = 10.0
 
-# Boost blocks over marketplace agents in search results
-BLOCK_SCORE_BOOST = 50.0
-
-# Block IDs to exclude from search results
-EXCLUDED_BLOCK_IDS = frozenset(
-    {
-        "e189baac-8c20-45a1-94a7-55177ea42565",  # AgentExecutorBlock
-    }
-)
-
 SearchResultItem = BlockInfo | library_model.LibraryAgent | store_model.StoreAgent
 
 
@@ -75,8 +64,8 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
 
     for block_type in load_all_blocks().values():
         block: AnyBlockSchema = block_type()
-        # Skip disabled and excluded blocks
-        if block.disabled or block.id in EXCLUDED_BLOCK_IDS:
+        # Skip disabled blocks
+        if block.disabled:
             continue
         # Skip blocks that don't have categories (all should have at least one)
         if not block.categories:
@@ -127,9 +116,6 @@ def get_blocks(
         # Skip disabled blocks
         if block.disabled:
             continue
-        # Skip excluded blocks
-        if block.id in EXCLUDED_BLOCK_IDS:
-            continue
         # Skip blocks that don't match the category
         if category and category not in {c.name.lower() for c in block.categories}:
            continue
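The removed `EXCLUDED_BLOCK_IDS` check is a plain set-membership filter applied while iterating blocks; its effect, reduced to dicts (the first ID is the AgentExecutorBlock ID from the old constant, the other two are made up for the example):

```python
# Same constant the old code declared (AgentExecutorBlock's ID)
EXCLUDED_BLOCK_IDS = frozenset({"e189baac-8c20-45a1-94a7-55177ea42565"})

blocks = [
    {"id": "e189baac-8c20-45a1-94a7-55177ea42565", "disabled": False},
    {"id": "00000000-0000-0000-0000-000000000001", "disabled": True},
    {"id": "00000000-0000-0000-0000-000000000002", "disabled": False},
]

# Old-side behaviour: drop disabled blocks and anything in the exclusion set;
# the right-hand side of the diff keeps only the disabled check.
visible = [
    b for b in blocks
    if not b["disabled"] and b["id"] not in EXCLUDED_BLOCK_IDS
]
```

A `frozenset` makes membership tests O(1) and signals that the exclusion list is fixed at import time.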
@@ -269,25 +255,14 @@ async def _build_cached_search_results(
         "my_agents": 0,
     }
 
-    # Use hybrid search when query is present, otherwise list all blocks
-    if (include_blocks or include_integrations) and normalized_query:
-        block_results, block_total, integration_total = await _hybrid_search_blocks(
-            query=search_query,
-            include_blocks=include_blocks,
-            include_integrations=include_integrations,
-        )
-        scored_items.extend(block_results)
-        total_items["blocks"] = block_total
-        total_items["integrations"] = integration_total
-    elif include_blocks or include_integrations:
-        # No query - list all blocks using in-memory approach
-        block_results, block_total, integration_total = _collect_block_results(
-            include_blocks=include_blocks,
-            include_integrations=include_integrations,
-        )
-        scored_items.extend(block_results)
-        total_items["blocks"] = block_total
-        total_items["integrations"] = integration_total
+    block_results, block_total, integration_total = _collect_block_results(
+        normalized_query=normalized_query,
+        include_blocks=include_blocks,
+        include_integrations=include_integrations,
+    )
+    scored_items.extend(block_results)
+    total_items["blocks"] = block_total
+    total_items["integrations"] = integration_total
 
     if include_library_agents:
         library_response = await library_db.list_library_agents(
@@ -332,14 +307,10 @@ async def _build_cached_search_results(
 
 def _collect_block_results(
     *,
+    normalized_query: str,
     include_blocks: bool,
     include_integrations: bool,
 ) -> tuple[list[_ScoredItem], int, int]:
-    """
-    Collect all blocks for listing (no search query).
-
-    All blocks get BLOCK_SCORE_BOOST to prioritize them over marketplace agents.
-    """
     results: list[_ScoredItem] = []
     block_count = 0
     integration_count = 0
@@ -352,10 +323,6 @@ def _collect_block_results(
         if block.disabled:
             continue
 
-        # Skip excluded blocks
-        if block.id in EXCLUDED_BLOCK_IDS:
-            continue
-
         block_info = block.get_info()
         credentials = list(block.input_schema.get_credentials_fields().values())
         is_integration = len(credentials) > 0
@@ -365,6 +332,10 @@ def _collect_block_results(
         if not is_integration and not include_blocks:
             continue
 
+        score = _score_block(block, block_info, normalized_query)
+        if not _should_include_item(score, normalized_query):
+            continue
+
         filter_type: FilterType = "integrations" if is_integration else "blocks"
         if is_integration:
             integration_count += 1
@@ -375,122 +346,8 @@ def _collect_block_results(
             _ScoredItem(
                 item=block_info,
                 filter_type=filter_type,
-                score=BLOCK_SCORE_BOOST,
-                sort_key=block_info.name.lower(),
-            )
-        )
-
-    return results, block_count, integration_count
-
-
-async def _hybrid_search_blocks(
-    *,
-    query: str,
-    include_blocks: bool,
-    include_integrations: bool,
-) -> tuple[list[_ScoredItem], int, int]:
-    """
-    Search blocks using hybrid search with builder-specific filtering.
-
-    Uses unified_hybrid_search for semantic + lexical search, then applies
-    post-filtering for block/integration types and scoring adjustments.
-
-    Scoring:
-    - Base: hybrid relevance score (0-1) scaled to 0-100, plus BLOCK_SCORE_BOOST
-      to prioritize blocks over marketplace agents in combined results
-    - +30 for exact name match, +15 for prefix name match
-    - +20 if the block has an LlmModel field and the query matches an LLM model name
-
-    Args:
-        query: The search query string
-        include_blocks: Whether to include regular blocks
-        include_integrations: Whether to include integration blocks
-
-    Returns:
-        Tuple of (scored_items, block_count, integration_count)
-    """
-    results: list[_ScoredItem] = []
-    block_count = 0
-    integration_count = 0
-
-    if not include_blocks and not include_integrations:
-        return results, block_count, integration_count
-
-    normalized_query = query.strip().lower()
-
-    # Fetch more results to account for post-filtering
-    search_results, _ = await unified_hybrid_search(
-        query=query,
-        content_types=[ContentType.BLOCK],
-        page=1,
-        page_size=150,
-        min_score=0.10,
-    )
-
-    # Load all blocks for getting BlockInfo
-    all_blocks = load_all_blocks()
-
-    for result in search_results:
-        block_id = result["content_id"]
-
-        # Skip excluded blocks
-        if block_id in EXCLUDED_BLOCK_IDS:
-            continue
-
-        metadata = result.get("metadata", {})
-        hybrid_score = result.get("relevance", 0.0)
-
-        # Get the actual block class
-        if block_id not in all_blocks:
-            continue
-
-        block_cls = all_blocks[block_id]
-        block: AnyBlockSchema = block_cls()
-
-        if block.disabled:
-            continue
-
-        # Check block/integration filter using metadata
-        is_integration = metadata.get("is_integration", False)
-
-        if is_integration and not include_integrations:
-            continue
-        if not is_integration and not include_blocks:
-            continue
-
-        # Get block info
-        block_info = block.get_info()
-
-        # Calculate final score: scale hybrid score and add builder-specific bonuses
-        # Hybrid scores are 0-1, builder scores were 0-200+
-        # Add BLOCK_SCORE_BOOST to prioritize blocks over marketplace agents
-        final_score = hybrid_score * 100 + BLOCK_SCORE_BOOST
-
-        # Add LLM model match bonus
-        has_llm_field = metadata.get("has_llm_model_field", False)
-        if has_llm_field and _matches_llm_model(block.input_schema, normalized_query):
-            final_score += 20
-
-        # Add exact/prefix match bonus for deterministic tie-breaking
-        name = block_info.name.lower()
-        if name == normalized_query:
-            final_score += 30
-        elif name.startswith(normalized_query):
-            final_score += 15
-
-        # Track counts
-        filter_type: FilterType = "integrations" if is_integration else "blocks"
-        if is_integration:
-            integration_count += 1
-        else:
-            block_count += 1
-
-        results.append(
-            _ScoredItem(
-                item=block_info,
-                filter_type=filter_type,
-                score=final_score,
-                sort_key=name,
+                score=score,
+                sort_key=_get_item_name(block_info),
             )
         )
 
@@ -615,8 +472,6 @@ async def _get_static_counts():
         block: AnyBlockSchema = block_type()
         if block.disabled:
             continue
-        if block.id in EXCLUDED_BLOCK_IDS:
-            continue
 
         all_blocks += 1
 
@@ -643,25 +498,47 @@ async def _get_static_counts():
     }
 
 
-def _contains_type(annotation: Any, target: type) -> bool:
-    """Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
-    if annotation is target:
-        return True
-    origin = get_origin(annotation)
-    if origin is None:
-        return False
-    return any(_contains_type(arg, target) for arg in get_args(annotation))
-
-
 def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
     for field in schema_cls.model_fields.values():
-        if _contains_type(field.annotation, LlmModel):
+        if field.annotation == LlmModel:
             # Check if query matches any value in llm_models
             if any(query in name for name in llm_models):
                 return True
     return False
 
 
+def _score_block(
+    block: AnyBlockSchema,
+    block_info: BlockInfo,
+    normalized_query: str,
+) -> float:
+    if not normalized_query:
+        return 0.0
+
+    name = block_info.name.lower()
+    description = block_info.description.lower()
+    score = _score_primary_fields(name, description, normalized_query)
+
+    category_text = " ".join(
+        category.get("category", "").lower() for category in block_info.categories
+    )
+    score += _score_additional_field(category_text, normalized_query, 12, 6)
+
+    credentials_info = block.input_schema.get_credentials_fields_info().values()
+    provider_names = [
+        provider.value.lower()
+        for info in credentials_info
+        for provider in info.provider
+    ]
+    provider_text = " ".join(provider_names)
+    score += _score_additional_field(provider_text, normalized_query, 15, 6)
+
+    if _matches_llm_model(block.input_schema, normalized_query):
+        score += 20
+
+    return score
+
+
 def _score_library_agent(
     agent: library_model.LibraryAgent,
     normalized_query: str,
@@ -768,20 +645,31 @@ def _get_all_providers() -> dict[ProviderName, Provider]:
     return providers
 
 
-@cached(ttl_seconds=3600, shared_cache=True)
+@cached(ttl_seconds=3600)
 async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
-    """Return the most-executed blocks from the last 14 days.
+    suggested_blocks = []
+    # Sum the number of executions for each block type
+    # Prisma cannot group by nested relations, so we do a raw query
+    # Calculate the cutoff timestamp
+    timestamp_threshold = datetime.now(timezone.utc) - timedelta(days=30)
 
-    Queries the mv_suggested_blocks materialized view (refreshed hourly via pg_cron)
-    and returns the top `count` blocks sorted by execution count, excluding
-    Input/Output/Agent block types and blocks in EXCLUDED_BLOCK_IDS.
-    """
-    results = await mv_suggested_blocks.prisma().find_many()
+    results = await query_raw_with_schema(
+        """
+        SELECT
+            agent_node."agentBlockId" AS block_id,
+            COUNT(execution.id) AS execution_count
+        FROM {schema_prefix}"AgentNodeExecution" execution
+        JOIN {schema_prefix}"AgentNode" agent_node ON execution."agentNodeId" = agent_node.id
+        WHERE execution."endedTime" >= $1::timestamp
+        GROUP BY agent_node."agentBlockId"
+        ORDER BY execution_count DESC;
+        """,
+        timestamp_threshold,
+    )
 
     # Get the top blocks based on execution count
-    # But ignore Input, Output, Agent, and excluded blocks
+    # But ignore Input and Output blocks
     blocks: list[tuple[BlockInfo, int]] = []
-    execution_counts = {row.block_id: row.execution_count for row in results}
 
     for block_type in load_all_blocks().values():
         block: AnyBlockSchema = block_type()
@@ -791,9 +679,11 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
             BlockType.AGENT,
         ):
             continue
-        if block.id in EXCLUDED_BLOCK_IDS:
-            continue
-        execution_count = execution_counts.get(block.id, 0)
+        # Find the execution count for this block
+        execution_count = next(
+            (row["execution_count"] for row in results if row["block_id"] == block.id),
+            0,
+        )
         blocks.append((block.get_info(), execution_count))
     # Sort blocks by execution count
     blocks.sort(key=lambda x: x[1], reverse=True)
@@ -27,6 +27,7 @@ class SearchEntry(BaseModel):
 
 # Suggestions
 class SuggestionsResponse(BaseModel):
+    otto_suggestions: list[str]
     recent_searches: list[SearchEntry]
     providers: list[ProviderName]
    top_blocks: list[BlockInfo]
@@ -1,5 +1,5 @@
 import logging
-from typing import Annotated, Sequence, cast, get_args
+from typing import Annotated, Sequence
 
 import fastapi
 from autogpt_libs.auth.dependencies import get_user_id, requires_user
@@ -10,8 +10,6 @@ from backend.util.models import Pagination
 from . import db as builder_db
 from . import model as builder_model
 
-VALID_FILTER_VALUES = get_args(builder_model.FilterType)
-
 logger = logging.getLogger(__name__)
 
 router = fastapi.APIRouter(
@@ -51,6 +49,11 @@ async def get_suggestions(
     Get all suggestions for the Blocks Menu.
     """
     return builder_model.SuggestionsResponse(
+        otto_suggestions=[
+            "What blocks do I need to get started?",
+            "Help me create a list",
+            "Help me feed my data to Google Maps",
+        ],
         recent_searches=await builder_db.get_recent_searches(user_id),
         providers=[
             ProviderName.TWITTER,
@@ -148,7 +151,7 @@ async def get_providers(
 async def search(
     user_id: Annotated[str, fastapi.Security(get_user_id)],
     search_query: Annotated[str | None, fastapi.Query()] = None,
-    filter: Annotated[str | None, fastapi.Query()] = None,
+    filter: Annotated[list[builder_model.FilterType] | None, fastapi.Query()] = None,
     search_id: Annotated[str | None, fastapi.Query()] = None,
     by_creator: Annotated[list[str] | None, fastapi.Query()] = None,
     page: Annotated[int, fastapi.Query()] = 1,
@@ -157,20 +160,9 @@ async def search(
     """
     Search for blocks (including integrations), marketplace agents, and user library agents.
     """
-    # Parse and validate filter parameter
-    filters: list[builder_model.FilterType]
-    if filter:
-        filter_values = [f.strip() for f in filter.split(",")]
-        invalid_filters = [f for f in filter_values if f not in VALID_FILTER_VALUES]
-        if invalid_filters:
-            raise fastapi.HTTPException(
-                status_code=400,
-                detail=f"Invalid filter value(s): {', '.join(invalid_filters)}. "
-                f"Valid values are: {', '.join(VALID_FILTER_VALUES)}",
-            )
-        filters = cast(list[builder_model.FilterType], filter_values)
-    else:
-        filters = [
+    # If no filters are provided, then we will return all types
+    if not filter:
+        filter = [
             "blocks",
             "integrations",
             "marketplace_agents",
@@ -182,7 +174,7 @@ async def search(
     cached_results = await builder_db.get_sorted_search_results(
         user_id=user_id,
         search_query=search_query,
-        filters=filters,
+        filters=filter,
         by_creator=by_creator,
     )
 
@@ -204,7 +196,7 @@ async def search(
             user_id,
             builder_model.SearchEntry(
                 search_query=search_query,
-                filter=filters,
+                filter=filter,
                 by_creator=by_creator,
                 search_id=search_id,
             ),
@@ -2,21 +2,23 @@
 
 import asyncio
 import logging
-import re
+import uuid as uuid_module
 from collections.abc import AsyncGenerator
 from typing import Annotated
-from uuid import uuid4
 
 from autogpt_libs import auth
-from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
+from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
 from fastapi.responses import StreamingResponse
-from prisma.models import UserWorkspaceFile
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel
 
 from backend.copilot import service as chat_service
 from backend.copilot import stream_registry
+from backend.copilot.completion_handler import (
+    process_operation_failure,
+    process_operation_success,
+)
 from backend.copilot.config import ChatConfig
-from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
+from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_task
 from backend.copilot.model import (
     ChatMessage,
     ChatSession,
@@ -25,10 +27,8 @@ from backend.copilot.model import (
     delete_chat_session,
     get_chat_session,
     get_user_sessions,
-    update_session_title,
 )
 from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
-from backend.copilot.tools.e2b_sandbox import kill_sandbox
 from backend.copilot.tools.models import (
     AgentDetailsResponse,
     AgentOutputResponse,
@@ -44,25 +44,20 @@ from backend.copilot.tools.models import (
     ErrorResponse,
     ExecutionStartedResponse,
     InputValidationErrorResponse,
-    MCPToolOutputResponse,
-    MCPToolsDiscoveredResponse,
     NeedLoginResponse,
     NoResultsResponse,
+    OperationInProgressResponse,
+    OperationPendingResponse,
+    OperationStartedResponse,
     SetupRequirementsResponse,
     SuggestedGoalResponse,
     UnderstandingUpdatedResponse,
 )
 from backend.copilot.tracking import track_user_message
-from backend.data.redis_client import get_redis_async
-from backend.data.understanding import get_business_understanding
-from backend.data.workspace import get_or_create_workspace
 from backend.util.exceptions import NotFoundError
 
 config = ChatConfig()
 
-_UUID_RE = re.compile(
-    r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
-)
-
 logger = logging.getLogger(__name__)
 
@@ -91,9 +86,6 @@ class StreamChatRequest(BaseModel):
     message: str
     is_user_message: bool = True
     context: dict[str, str] | None = None  # {url: str, content: str}
-    file_ids: list[str] | None = Field(
-        default=None, max_length=20
-    )  # Workspace file IDs attached to this message
 
 
 class CreateSessionResponse(BaseModel):
@@ -107,8 +99,10 @@ class CreateSessionResponse(BaseModel):
 class ActiveStreamInfo(BaseModel):
     """Information about an active stream for reconnection."""
 
-    turn_id: str
+    task_id: str
     last_message_id: str  # Redis Stream message ID for resumption
+    operation_id: str  # Operation ID for completion tracking
+    tool_name: str  # Name of the tool being executed
 
 
 class SessionDetailResponse(BaseModel):
@@ -129,7 +123,6 @@ class SessionSummaryResponse(BaseModel):
     created_at: str
     updated_at: str
     title: str | None = None
-    is_processing: bool
 
 
 class ListSessionsResponse(BaseModel):
@@ -139,25 +132,20 @@ class ListSessionsResponse(BaseModel):
     total: int
 
 
-class CancelSessionResponse(BaseModel):
-    """Response model for the cancel session endpoint."""
+class CancelTaskResponse(BaseModel):
+    """Response model for the cancel task endpoint."""
 
     cancelled: bool
+    task_id: str | None = None
     reason: str | None = None
 
 
-class UpdateSessionTitleRequest(BaseModel):
-    """Request model for updating a session's title."""
+class OperationCompleteRequest(BaseModel):
+    """Request model for external completion webhook."""
 
-    title: str
-
-    @field_validator("title")
-    @classmethod
-    def title_must_not_be_blank(cls, v: str) -> str:
-        stripped = v.strip()
-        if not stripped:
-            raise ValueError("Title must not be blank")
-        return stripped
+    success: bool
+    result: dict | str | None = None
+    error: str | None = None
 
 
 # ========== Routes ==========
@@ -188,28 +176,6 @@ async def list_sessions(
     """
     sessions, total_count = await get_user_sessions(user_id, limit, offset)
 
-    # Batch-check Redis for active stream status on each session
-    processing_set: set[str] = set()
-    if sessions:
-        try:
-            redis = await get_redis_async()
-            pipe = redis.pipeline(transaction=False)
-            for session in sessions:
-                pipe.hget(
-                    f"{config.session_meta_prefix}{session.session_id}",
-                    "status",
-                )
-            statuses = await pipe.execute()
-            processing_set = {
-                session.session_id
-                for session, st in zip(sessions, statuses)
-                if st == "running"
-            }
-        except Exception:
-            logger.warning(
-                "Failed to fetch processing status from Redis; " "defaulting to empty"
-            )
-
     return ListSessionsResponse(
         sessions=[
             SessionSummaryResponse(
@@ -217,7 +183,6 @@ async def list_sessions(
                 created_at=session.started_at.isoformat(),
                 updated_at=session.updated_at.isoformat(),
                 title=session.title,
-                is_processing=session.session_id in processing_set,
             )
             for session in sessions
         ],
@@ -291,58 +256,9 @@ async def delete_session(
             detail=f"Session {session_id} not found or access denied",
         )
 
-    # Best-effort cleanup of the E2B sandbox (if any).
-    # sandbox_id is in Redis; kill_sandbox() fetches it from there.
-    e2b_cfg = ChatConfig()
-    if e2b_cfg.e2b_active:
-        assert e2b_cfg.e2b_api_key  # guaranteed by e2b_active check
-        try:
-            await kill_sandbox(session_id, e2b_cfg.e2b_api_key)
-        except Exception:
-            logger.warning(
-                "[E2B] Failed to kill sandbox for session %s", session_id[:12]
-            )
-
     return Response(status_code=204)
 
 
-@router.patch(
-    "/sessions/{session_id}/title",
-    summary="Update session title",
-    dependencies=[Security(auth.requires_user)],
-    status_code=200,
-    responses={404: {"description": "Session not found or access denied"}},
-)
-async def update_session_title_route(
-    session_id: str,
-    request: UpdateSessionTitleRequest,
-    user_id: Annotated[str, Security(auth.get_user_id)],
-) -> dict:
-    """
-    Update the title of a chat session.
-
-    Allows the user to rename their chat session.
-
-    Args:
-        session_id: The session ID to update.
-        request: Request body containing the new title.
-        user_id: The authenticated user's ID.
-
-    Returns:
-        dict: Status of the update.
-
-    Raises:
-        HTTPException: 404 if session not found or not owned by user.
-    """
-    success = await update_session_title(session_id, user_id, request.title)
-    if not success:
-        raise HTTPException(
-            status_code=404,
-            detail=f"Session {session_id} not found or access denied",
-        )
-    return {"status": "ok"}
-
-
 @router.get(
     "/sessions/{session_id}",
 )
@@ -354,7 +270,7 @@ async def get_session(
     Retrieve the details of a specific chat session.
 
     Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
-    If there's an active stream for this session, returns active_stream info for reconnection.
+    If there's an active stream for this session, returns the task_id for reconnection.
 
     Args:
         session_id: The unique identifier for the desired chat session.
@@ -372,21 +288,28 @@ async def get_session(
 
     # Check if there's an active stream for this session
     active_stream_info = None
-    active_session, last_message_id = await stream_registry.get_active_session(
+    active_task, last_message_id = await stream_registry.get_active_task_for_session(
         session_id, user_id
     )
     logger.info(
-        f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
+        f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
        f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
     )
-    if active_session:
-        # Keep the assistant message (including tool_calls) so the frontend can
-        # render the correct tool UI (e.g. CreateAgent with mini game).
-        # convertChatSessionToUiMessages handles isComplete=false by setting
-        # tool parts without output to state "input-available".
+    if active_task:
+        # Filter out the in-progress assistant message from the session response.
+        # The client will receive the complete assistant response through the SSE
+        # stream replay instead, preventing duplicate content.
+        if messages and messages[-1].get("role") == "assistant":
+            messages = messages[:-1]
+
+        # Use "0-0" as last_message_id to replay the stream from the beginning.
+        # Since we filtered out the cached assistant message, the client needs
+        # the full stream to reconstruct the response.
         active_stream_info = ActiveStreamInfo(
-            turn_id=active_session.turn_id,
-            last_message_id=last_message_id,
+            task_id=active_task.task_id,
+            last_message_id="0-0",
+            operation_id=active_task.operation_id,
+            tool_name=active_task.tool_name,
         )
 
     return SessionDetailResponse(
@@ -406,7 +329,7 @@ async def get_session(
 async def cancel_session_task(
     session_id: str,
     user_id: Annotated[str | None, Depends(auth.get_user_id)],
-) -> CancelSessionResponse:
+) -> CancelTaskResponse:
     """Cancel the active streaming task for a session.
 
     Publishes a cancel event to the executor via RabbitMQ FANOUT, then
@@ -415,33 +338,39 @@ async def cancel_session_task(
|
|||||||
"""
|
"""
|
||||||
await _validate_and_get_session(session_id, user_id)
|
await _validate_and_get_session(session_id, user_id)
|
||||||
|
|
||||||
active_session, _ = await stream_registry.get_active_session(session_id, user_id)
|
active_task, _ = await stream_registry.get_active_task_for_session(
|
||||||
if not active_session:
|
session_id, user_id
|
||||||
return CancelSessionResponse(cancelled=True, reason="no_active_session")
|
)
|
||||||
|
if not active_task:
|
||||||
|
return CancelTaskResponse(cancelled=False, reason="no_active_task")
|
||||||
|
|
||||||
await enqueue_cancel_task(session_id)
|
task_id = active_task.task_id
|
||||||
logger.info(f"[CANCEL] Published cancel for session ...{session_id[-8:]}")
|
await enqueue_cancel_task(task_id)
|
||||||
|
logger.info(
|
||||||
|
f"[CANCEL] Published cancel for task ...{task_id[-8:]} "
|
||||||
|
f"session ...{session_id[-8:]}"
|
||||||
|
)
|
||||||
|
|
||||||
# Poll until the executor confirms the task is no longer running.
|
# Poll until the executor confirms the task is no longer running.
|
||||||
|
# Keep max_wait below typical reverse-proxy read timeouts.
|
||||||
poll_interval = 0.5
|
poll_interval = 0.5
|
||||||
max_wait = 5.0
|
max_wait = 5.0
|
||||||
waited = 0.0
|
waited = 0.0
|
||||||
while waited < max_wait:
|
while waited < max_wait:
|
||||||
await asyncio.sleep(poll_interval)
|
await asyncio.sleep(poll_interval)
|
||||||
waited += poll_interval
|
waited += poll_interval
|
||||||
session_state = await stream_registry.get_session(session_id)
|
task = await stream_registry.get_task(task_id)
|
||||||
if session_state is None or session_state.status != "running":
|
if task is None or task.status != "running":
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[CANCEL] Session ...{session_id[-8:]} confirmed stopped "
|
f"[CANCEL] Task ...{task_id[-8:]} confirmed stopped "
|
||||||
f"(status={session_state.status if session_state else 'gone'}) after {waited:.1f}s"
|
f"(status={task.status if task else 'gone'}) after {waited:.1f}s"
|
||||||
)
|
)
|
||||||
return CancelSessionResponse(cancelled=True)
|
return CancelTaskResponse(cancelled=True, task_id=task_id)
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(f"[CANCEL] Task ...{task_id[-8:]} not confirmed after {max_wait}s")
|
||||||
f"[CANCEL] Session ...{session_id[-8:]} not confirmed after {max_wait}s, force-completing"
|
return CancelTaskResponse(
|
||||||
|
cancelled=True, task_id=task_id, reason="cancel_published_not_confirmed"
|
||||||
)
|
)
|
||||||
await stream_registry.mark_session_completed(session_id, error_message="Cancelled")
|
|
||||||
return CancelSessionResponse(cancelled=True)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
```diff
@@ -461,15 +390,16 @@ async def stream_chat_post(
     - Tool execution results

     The AI generation runs in a background task that continues even if the client disconnects.
-    All chunks are written to a per-turn Redis stream for reconnection support. If the client
-    disconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.
+    All chunks are written to Redis for reconnection support. If the client disconnects,
+    they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.

     Args:
         session_id: The chat session identifier to associate with the streamed messages.
         request: Request body containing message, is_user_message, and optional context.
         user_id: Optional authenticated user ID.
     Returns:
-        StreamingResponse: SSE-formatted response chunks.
+        StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
+        containing the task_id for reconnection.

     """
     import asyncio
```
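Per the updated docstring, a disconnected client resumes via GET /tasks/{task_id}/stream. A small sketch of building that resume URL from the last Redis Stream ID the client saw (the helper name and base URL are illustrative):

```python
from urllib.parse import urlencode


def reconnect_url(base: str, task_id: str, last_message_id: str = "0-0") -> str:
    """Build the resume URL for GET /tasks/{task_id}/stream.

    Passing the last Redis Stream ID the client saw lets the server replay
    only the gap; the default "0-0" asks for a full replay.
    """
    query = urlencode({"last_message_id": last_message_id})
    return f"{base}/tasks/{task_id}/stream?{query}"


print(reconnect_url("https://api.example.com", "abc123", "1706540123456-0"))
```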
```diff
@@ -496,38 +426,6 @@ async def stream_chat_post(
         },
     )

-    # Enrich message with file metadata if file_ids are provided.
-    # Also sanitise file_ids so only validated, workspace-scoped IDs are
-    # forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
-    sanitized_file_ids: list[str] | None = None
-    if request.file_ids and user_id:
-        # Filter to valid UUIDs only to prevent DB abuse
-        valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
-
-        if valid_ids:
-            workspace = await get_or_create_workspace(user_id)
-            # Batch query instead of N+1
-            files = await UserWorkspaceFile.prisma().find_many(
-                where={
-                    "id": {"in": valid_ids},
-                    "workspaceId": workspace.id,
-                    "isDeleted": False,
-                }
-            )
-            # Only keep IDs that actually exist in the user's workspace
-            sanitized_file_ids = [wf.id for wf in files] or None
-            file_lines: list[str] = [
-                f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
-                for wf in files
-            ]
-            if file_lines:
-                files_block = (
-                    "\n\n[Attached files]\n"
-                    + "\n".join(file_lines)
-                    + "\nUse read_workspace_file with the file_id to access file contents."
-                )
-                request.message += files_block
-
     # Atomically append user message to session BEFORE creating task to avoid
     # race condition where GET_SESSION sees task as "running" but message isn't
     # saved yet. append_and_save_message re-fetches inside a lock to prevent
@@ -548,38 +446,37 @@ async def stream_chat_post(
     logger.info(f"[STREAM] User message saved for session {session_id}")

     # Create a task in the stream registry for reconnection support
-    turn_id = str(uuid4())
-    log_meta["turn_id"] = turn_id
+    task_id = str(uuid_module.uuid4())
+    operation_id = str(uuid_module.uuid4())
+    log_meta["task_id"] = task_id

-    session_create_start = time.perf_counter()
-    await stream_registry.create_session(
+    task_create_start = time.perf_counter()
+    await stream_registry.create_task(
+        task_id=task_id,
         session_id=session_id,
         user_id=user_id,
-        tool_call_id="chat_stream",
+        tool_call_id="chat_stream",  # Not a tool call, but needed for the model
         tool_name="chat",
-        turn_id=turn_id,
+        operation_id=operation_id,
     )
     logger.info(
-        f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
+        f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
         extra={
             "json_fields": {
                 **log_meta,
-                "duration_ms": (time.perf_counter() - session_create_start) * 1000,
+                "duration_ms": (time.perf_counter() - task_create_start) * 1000,
             }
         },
     )

-    # Per-turn stream is always fresh (unique turn_id), subscribe from beginning
-    subscribe_from_id = "0-0"
-
-    await enqueue_copilot_turn(
+    await enqueue_copilot_task(
+        task_id=task_id,
         session_id=session_id,
         user_id=user_id,
+        operation_id=operation_id,
         message=request.message,
-        turn_id=turn_id,
         is_user_message=request.is_user_message,
         context=request.context,
-        file_ids=sanitized_file_ids,
     )

     setup_time = (time.perf_counter() - stream_start_time) * 1000
@@ -594,7 +491,7 @@ async def stream_chat_post(

     event_gen_start = time_module.perf_counter()
     logger.info(
-        f"[TIMING] event_generator STARTED, turn={turn_id}, session={session_id}, "
+        f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
         f"user={user_id}",
         extra={"json_fields": log_meta},
     )
@@ -602,12 +499,11 @@ async def stream_chat_post(
     first_chunk_yielded = False
     chunks_yielded = 0
     try:
-        # Subscribe from the position we captured before enqueuing
-        # This avoids replaying old messages while catching all new ones
-        subscriber_queue = await stream_registry.subscribe_to_session(
-            session_id=session_id,
+        # Subscribe to the task stream (this replays existing messages + live updates)
+        subscriber_queue = await stream_registry.subscribe_to_task(
+            task_id=task_id,
             user_id=user_id,
-            last_message_id=subscribe_from_id,
+            last_message_id="0-0",  # Get all messages from the beginning
         )

         if subscriber_queue is None:
```
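Subscribing with `last_message_id="0-0"` replays the task stream from the beginning, while a concrete ID resumes strictly after it. Redis Stream IDs are "<ms>-<seq>" pairs compared numerically; a pure-Python sketch of that cursor semantics (the `replay_after` helper is illustrative, not the registry's actual API):

```python
def replay_after(entries, last_message_id: str = "0-0"):
    """Return stream entries strictly after last_message_id.

    Mimics Redis Stream cursor semantics: IDs are "<ms>-<seq>" pairs compared
    numerically, and "0-0" replays everything.
    """
    def parse(stream_id: str):
        ms, _, seq = stream_id.partition("-")
        return (int(ms), int(seq or 0))

    cursor = parse(last_message_id)
    return [(sid, data) for sid, data in entries if parse(sid) > cursor]


stream = [("1-0", "start"), ("1-1", "text-delta"), ("2-0", "finish")]
print(replay_after(stream))          # full replay from "0-0"
print(replay_after(stream, "1-0"))   # resume: only entries after "1-0"
```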
```diff
@@ -622,7 +518,7 @@ async def stream_chat_post(
             )
             while True:
                 try:
-                    chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
+                    chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
                     chunks_yielded += 1

                     if not first_chunk_yielded:
@@ -690,19 +586,19 @@ async def stream_chat_post(
             # Unsubscribe when client disconnects or stream ends
             if subscriber_queue is not None:
                 try:
-                    await stream_registry.unsubscribe_from_session(
-                        session_id, subscriber_queue
+                    await stream_registry.unsubscribe_from_task(
+                        task_id, subscriber_queue
                     )
                 except Exception as unsub_err:
                     logger.error(
-                        f"Error unsubscribing from session {session_id}: {unsub_err}",
+                        f"Error unsubscribing from task {task_id}: {unsub_err}",
                         exc_info=True,
                     )
             # AI SDK protocol termination - always yield even if unsubscribe fails
             total_time = time_module.perf_counter() - event_gen_start
             logger.info(
                 f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
-                f"turn={turn_id}, session={session_id}, n_chunks={chunks_yielded}",
+                f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
                 extra={
                     "json_fields": {
                         **log_meta,
@@ -749,21 +645,17 @@ async def resume_session_stream(
     """
     import asyncio

-    active_session, last_message_id = await stream_registry.get_active_session(
+    active_task, _last_id = await stream_registry.get_active_task_for_session(
         session_id, user_id
     )

-    if not active_session:
+    if not active_task:
         return Response(status_code=204)

-    # Always replay from the beginning ("0-0") on resume.
-    # We can't use last_message_id because it's the latest ID in the backend
-    # stream, not the latest the frontend received — the gap causes lost
-    # messages. The frontend deduplicates replayed content.
-    subscriber_queue = await stream_registry.subscribe_to_session(
-        session_id=session_id,
+    subscriber_queue = await stream_registry.subscribe_to_task(
+        task_id=active_task.task_id,
         user_id=user_id,
-        last_message_id="0-0",
+        last_message_id="0-0",  # Full replay so useChat rebuilds the message
     )

     if subscriber_queue is None:
@@ -775,7 +667,7 @@ async def resume_session_stream(
     try:
         while True:
             try:
-                chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
+                chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
                 if chunk_count < 3:
                     logger.info(
                         "Resume stream chunk",
@@ -799,12 +691,12 @@ async def resume_session_stream(
         logger.error(f"Error in resume stream for session {session_id}: {e}")
     finally:
         try:
-            await stream_registry.unsubscribe_from_session(
-                session_id, subscriber_queue
+            await stream_registry.unsubscribe_from_task(
+                active_task.task_id, subscriber_queue
             )
         except Exception as unsub_err:
             logger.error(
-                f"Error unsubscribing from session {active_session.session_id}: {unsub_err}",
+                f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
                 exc_info=True,
             )
     logger.info(
@@ -832,6 +724,7 @@ async def resume_session_stream(
 @router.patch(
     "/sessions/{session_id}/assign-user",
     dependencies=[Security(auth.requires_user)],
+    status_code=200,
 )
 async def session_assign_user(
     session_id: str,
@@ -854,34 +747,227 @@ async def session_assign_user(
     return {"status": "ok"}


-# ========== Suggested Prompts ==========
+# ========== Task Streaming (SSE Reconnection) ==========


-class SuggestedPromptsResponse(BaseModel):
-    """Response model for user-specific suggested prompts."""
-
-    prompts: list[str]
-
-
 @router.get(
-    "/suggested-prompts",
-    dependencies=[Security(auth.requires_user)],
+    "/tasks/{task_id}/stream",
 )
-async def get_suggested_prompts(
-    user_id: Annotated[str, Security(auth.get_user_id)],
-) -> SuggestedPromptsResponse:
+async def stream_task(
+    task_id: str,
+    user_id: str | None = Depends(auth.get_user_id),
+    last_message_id: str = Query(
+        default="0-0",
+        description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
+    ),
+):
     """
-    Get LLM-generated suggested prompts for the authenticated user.
+    Reconnect to a long-running task's SSE stream.

-    Returns personalized quick-action prompts based on the user's
-    business understanding. Returns an empty list if no custom prompts
-    are available.
+    When a long-running operation (like agent generation) starts, the client
+    receives a task_id. If the connection drops, the client can reconnect
+    using this endpoint to resume receiving updates.
+
+    Args:
+        task_id: The task ID from the operation_started response.
+        user_id: Authenticated user ID for ownership validation.
+        last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
+
+    Returns:
+        StreamingResponse: SSE-formatted response chunks starting after last_message_id.
+
+    Raises:
+        HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
     """
-    understanding = await get_business_understanding(user_id)
-    if understanding is None:
-        return SuggestedPromptsResponse(prompts=[])
-
-    return SuggestedPromptsResponse(prompts=understanding.suggested_prompts)
+    # Check task existence and expiry before subscribing
+    task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
+
+    if error_code == "TASK_EXPIRED":
+        raise HTTPException(
+            status_code=410,
+            detail={
+                "code": "TASK_EXPIRED",
+                "message": "This operation has expired. Please try again.",
+            },
+        )
+
+    if error_code == "TASK_NOT_FOUND":
+        raise HTTPException(
+            status_code=404,
+            detail={
+                "code": "TASK_NOT_FOUND",
+                "message": f"Task {task_id} not found.",
+            },
+        )
+
+    # Validate ownership if task has an owner
+    if task and task.user_id and user_id != task.user_id:
+        raise HTTPException(
+            status_code=403,
+            detail={
+                "code": "ACCESS_DENIED",
+                "message": "You do not have access to this task.",
+            },
+        )
+
+    # Get subscriber queue from stream registry
+    subscriber_queue = await stream_registry.subscribe_to_task(
+        task_id=task_id,
+        user_id=user_id,
+        last_message_id=last_message_id,
+    )
+
+    if subscriber_queue is None:
+        raise HTTPException(
+            status_code=404,
+            detail={
+                "code": "TASK_NOT_FOUND",
+                "message": f"Task {task_id} not found or access denied.",
+            },
+        )
+
+    async def event_generator() -> AsyncGenerator[str, None]:
+        heartbeat_interval = 15.0  # Send heartbeat every 15 seconds
+        try:
+            while True:
+                try:
+                    # Wait for next chunk with timeout for heartbeats
+                    chunk = await asyncio.wait_for(
+                        subscriber_queue.get(), timeout=heartbeat_interval
+                    )
+                    yield chunk.to_sse()
+
+                    # Check for finish signal
+                    if isinstance(chunk, StreamFinish):
+                        break
+                except asyncio.TimeoutError:
+                    # Send heartbeat to keep connection alive
+                    yield StreamHeartbeat().to_sse()
+        except Exception as e:
+            logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
+        finally:
+            # Unsubscribe when client disconnects or stream ends
+            try:
+                await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
+            except Exception as unsub_err:
+                logger.error(
+                    f"Error unsubscribing from task {task_id}: {unsub_err}",
+                    exc_info=True,
+                )
+            # AI SDK protocol termination - always yield even if unsubscribe fails
+            yield "data: [DONE]\n\n"
+
+    return StreamingResponse(
+        event_generator(),
+        media_type="text/event-stream",
+        headers={
+            "Cache-Control": "no-cache",
+            "Connection": "keep-alive",
+            "X-Accel-Buffering": "no",
+            "x-vercel-ai-ui-message-stream": "v1",
+        },
+    )
+
+
+@router.get(
+    "/tasks/{task_id}",
+)
+async def get_task_status(
+    task_id: str,
+    user_id: str | None = Depends(auth.get_user_id),
+) -> dict:
+    """
+    Get the status of a long-running task.
+
+    Args:
+        task_id: The task ID to check.
+        user_id: Authenticated user ID for ownership validation.
+
+    Returns:
+        dict: Task status including task_id, status, tool_name, and operation_id.
+
+    Raises:
+        NotFoundError: If task_id is not found or user doesn't have access.
+    """
+    task = await stream_registry.get_task(task_id)
+
+    if task is None:
+        raise NotFoundError(f"Task {task_id} not found.")
+
+    # Validate ownership - if task has an owner, requester must match
+    if task.user_id and user_id != task.user_id:
+        raise NotFoundError(f"Task {task_id} not found.")
+
+    return {
+        "task_id": task.task_id,
+        "session_id": task.session_id,
+        "status": task.status,
+        "tool_name": task.tool_name,
+        "operation_id": task.operation_id,
+        "created_at": task.created_at.isoformat(),
+    }
+
+
+# ========== External Completion Webhook ==========
+
+
+@router.post(
+    "/operations/{operation_id}/complete",
+    status_code=200,
+)
+async def complete_operation(
+    operation_id: str,
+    request: OperationCompleteRequest,
+    x_api_key: str | None = Header(default=None),
+) -> dict:
+    """
+    External completion webhook for long-running operations.
+
+    Called by Agent Generator (or other services) when an operation completes.
+    This triggers the stream registry to publish completion and continue LLM generation.
+
+    Args:
+        operation_id: The operation ID to complete.
+        request: Completion payload with success status and result/error.
+        x_api_key: Internal API key for authentication.
+
+    Returns:
+        dict: Status of the completion.
+
+    Raises:
+        HTTPException: If API key is invalid or operation not found.
+    """
+    # Validate internal API key - reject if not configured or invalid
+    if not config.internal_api_key:
+        logger.error(
+            "Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
+        )
+        raise HTTPException(
+            status_code=503,
+            detail="Webhook not available: internal API key not configured",
+        )
+    if x_api_key != config.internal_api_key:
+        raise HTTPException(status_code=401, detail="Invalid API key")
+
+    # Find task by operation_id
+    task = await stream_registry.find_task_by_operation_id(operation_id)
+    if task is None:
+        raise HTTPException(
+            status_code=404,
+            detail=f"Operation {operation_id} not found",
+        )
+
+    logger.info(
+        f"Received completion webhook for operation {operation_id} "
+        f"(task_id={task.task_id}, success={request.success})"
+    )
+
+    if request.success:
+        await process_operation_success(task, request.result)
+    else:
+        await process_operation_failure(task, request.error)
+
+    return {"status": "ok", "task_id": task.task_id}


 # ========== Configuration ==========
```
```diff
@@ -964,8 +1050,9 @@ ToolResponseUnion = (
     | BlockOutputResponse
     | DocSearchResultsResponse
     | DocPageResponse
-    | MCPToolsDiscoveredResponse
-    | MCPToolOutputResponse
+    | OperationStartedResponse
+    | OperationPendingResponse
+    | OperationInProgressResponse
 )


@@ -1,310 +0,0 @@
-"""Tests for chat API routes: session title update, file attachment validation, and suggested prompts."""
-
-from unittest.mock import AsyncMock, MagicMock
-
-import fastapi
-import fastapi.testclient
-import pytest
-import pytest_mock
-
-from backend.api.features.chat import routes as chat_routes
-
-app = fastapi.FastAPI()
-app.include_router(chat_routes.router)
-
-client = fastapi.testclient.TestClient(app)
-
-TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
-
-
-@pytest.fixture(autouse=True)
-def setup_app_auth(mock_jwt_user):
-    """Setup auth overrides for all tests in this module"""
-    from autogpt_libs.auth.jwt_utils import get_jwt_payload
-
-    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
-    yield
-    app.dependency_overrides.clear()
-
-
-def _mock_update_session_title(
-    mocker: pytest_mock.MockerFixture, *, success: bool = True
-):
-    """Mock update_session_title."""
-    return mocker.patch(
-        "backend.api.features.chat.routes.update_session_title",
-        new_callable=AsyncMock,
-        return_value=success,
-    )
-
-
-# ─── Update title: success ─────────────────────────────────────────────
-
-
-def test_update_title_success(
-    mocker: pytest_mock.MockerFixture,
-    test_user_id: str,
-) -> None:
-    mock_update = _mock_update_session_title(mocker, success=True)
-
-    response = client.patch(
-        "/sessions/sess-1/title",
-        json={"title": "My project"},
-    )
-
-    assert response.status_code == 200
-    assert response.json() == {"status": "ok"}
-    mock_update.assert_called_once_with("sess-1", test_user_id, "My project")
-
-
-def test_update_title_trims_whitespace(
-    mocker: pytest_mock.MockerFixture,
-    test_user_id: str,
-) -> None:
-    mock_update = _mock_update_session_title(mocker, success=True)
-
-    response = client.patch(
-        "/sessions/sess-1/title",
-        json={"title": " trimmed "},
-    )
-
-    assert response.status_code == 200
-    mock_update.assert_called_once_with("sess-1", test_user_id, "trimmed")
-
-
-# ─── Update title: blank / whitespace-only → 422 ──────────────────────
-
-
-def test_update_title_blank_rejected(
-    test_user_id: str,
-) -> None:
-    """Whitespace-only titles must be rejected before hitting the DB."""
-    response = client.patch(
-        "/sessions/sess-1/title",
-        json={"title": " "},
-    )
-
-    assert response.status_code == 422
-
-
-def test_update_title_empty_rejected(
-    test_user_id: str,
-) -> None:
-    response = client.patch(
-        "/sessions/sess-1/title",
-        json={"title": ""},
-    )
-
-    assert response.status_code == 422
-
-
-# ─── Update title: session not found or wrong user → 404 ──────────────
-
-
-def test_update_title_not_found(
-    mocker: pytest_mock.MockerFixture,
-    test_user_id: str,
-) -> None:
-    _mock_update_session_title(mocker, success=False)
-
-    response = client.patch(
-        "/sessions/sess-1/title",
-        json={"title": "New name"},
-    )
-
-    assert response.status_code == 404
-
-
-# ─── file_ids Pydantic validation ─────────────────────────────────────
-
-
-def test_stream_chat_rejects_too_many_file_ids():
-    """More than 20 file_ids should be rejected by Pydantic validation (422)."""
-    response = client.post(
-        "/sessions/sess-1/stream",
-        json={
-            "message": "hello",
-            "file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(21)],
-        },
-    )
-    assert response.status_code == 422
-
-
-def _mock_stream_internals(mocker: pytest_mock.MockFixture):
-    """Mock the async internals of stream_chat_post so tests can exercise
-    validation and enrichment logic without needing Redis/RabbitMQ."""
-    mocker.patch(
-        "backend.api.features.chat.routes._validate_and_get_session",
-        return_value=None,
-    )
-    mocker.patch(
-        "backend.api.features.chat.routes.append_and_save_message",
-        return_value=None,
-    )
-    mock_registry = mocker.MagicMock()
-    mock_registry.create_session = mocker.AsyncMock(return_value=None)
-    mocker.patch(
-        "backend.api.features.chat.routes.stream_registry",
-        mock_registry,
-    )
-    mocker.patch(
-        "backend.api.features.chat.routes.enqueue_copilot_turn",
-        return_value=None,
-    )
-    mocker.patch(
-        "backend.api.features.chat.routes.track_user_message",
-        return_value=None,
-    )
-
-
-def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
-    """Exactly 20 file_ids should be accepted (not rejected by validation)."""
-    _mock_stream_internals(mocker)
-    # Patch workspace lookup as imported by the routes module
-    mocker.patch(
-        "backend.api.features.chat.routes.get_or_create_workspace",
-        return_value=type("W", (), {"id": "ws-1"})(),
-    )
-    mock_prisma = mocker.MagicMock()
-    mock_prisma.find_many = mocker.AsyncMock(return_value=[])
-    mocker.patch(
-        "prisma.models.UserWorkspaceFile.prisma",
-        return_value=mock_prisma,
-    )
-
-    response = client.post(
-        "/sessions/sess-1/stream",
-        json={
-            "message": "hello",
-            "file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(20)],
-        },
-    )
-    # Should get past validation — 200 streaming response expected
-    assert response.status_code == 200
-
-
-# ─── UUID format filtering ─────────────────────────────────────────────
-
-
-def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
-    """Non-UUID strings in file_ids should be silently filtered out
-    and NOT passed to the database query."""
-    _mock_stream_internals(mocker)
-    mocker.patch(
-        "backend.api.features.chat.routes.get_or_create_workspace",
-        return_value=type("W", (), {"id": "ws-1"})(),
-    )
-
-    mock_prisma = mocker.MagicMock()
-    mock_prisma.find_many = mocker.AsyncMock(return_value=[])
-    mocker.patch(
-        "prisma.models.UserWorkspaceFile.prisma",
-        return_value=mock_prisma,
-    )
-
-    valid_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
-    client.post(
-        "/sessions/sess-1/stream",
-        json={
-            "message": "hello",
-            "file_ids": [
-                valid_id,
-                "not-a-uuid",
-                "../../../etc/passwd",
-                "",
-            ],
-        },
-    )
-
-    # The find_many call should only receive the one valid UUID
-    mock_prisma.find_many.assert_called_once()
-    call_kwargs = mock_prisma.find_many.call_args[1]
-    assert call_kwargs["where"]["id"]["in"] == [valid_id]
-
-
-# ─── Cross-workspace file_ids ─────────────────────────────────────────
-
-
-def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
-    """The batch query should scope to the user's workspace."""
-    _mock_stream_internals(mocker)
-    mocker.patch(
-        "backend.api.features.chat.routes.get_or_create_workspace",
-        return_value=type("W", (), {"id": "my-workspace-id"})(),
-    )
-
-    mock_prisma = mocker.MagicMock()
-    mock_prisma.find_many = mocker.AsyncMock(return_value=[])
-    mocker.patch(
-        "prisma.models.UserWorkspaceFile.prisma",
-        return_value=mock_prisma,
-    )
-
-    fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
-    client.post(
-        "/sessions/sess-1/stream",
-        json={"message": "hi", "file_ids": [fid]},
```
|
||||||
)
|
|
||||||
|
|
||||||
call_kwargs = mock_prisma.find_many.call_args[1]
|
|
||||||
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
|
|
||||||
assert call_kwargs["where"]["isDeleted"] is False
|
|
||||||
|
|
||||||
|
|
||||||
# ─── Suggested prompts endpoint ──────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _mock_get_business_understanding(
|
|
||||||
mocker: pytest_mock.MockerFixture,
|
|
||||||
*,
|
|
||||||
return_value=None,
|
|
||||||
):
|
|
||||||
"""Mock get_business_understanding."""
|
|
||||||
return mocker.patch(
|
|
||||||
"backend.api.features.chat.routes.get_business_understanding",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
return_value=return_value,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_suggested_prompts_returns_prompts(
|
|
||||||
mocker: pytest_mock.MockerFixture,
|
|
||||||
test_user_id: str,
|
|
||||||
) -> None:
|
|
||||||
"""User with understanding and prompts gets them back."""
|
|
||||||
mock_understanding = MagicMock()
|
|
||||||
mock_understanding.suggested_prompts = ["Do X", "Do Y", "Do Z"]
|
|
||||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
|
||||||
|
|
||||||
response = client.get("/suggested-prompts")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json() == {"prompts": ["Do X", "Do Y", "Do Z"]}
|
|
||||||
|
|
||||||
|
|
||||||
def test_suggested_prompts_no_understanding(
|
|
||||||
mocker: pytest_mock.MockerFixture,
|
|
||||||
test_user_id: str,
|
|
||||||
) -> None:
|
|
||||||
"""User with no understanding gets empty list."""
|
|
||||||
_mock_get_business_understanding(mocker, return_value=None)
|
|
||||||
|
|
||||||
response = client.get("/suggested-prompts")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json() == {"prompts": []}
|
|
||||||
|
|
||||||
|
|
||||||
def test_suggested_prompts_empty_prompts(
|
|
||||||
mocker: pytest_mock.MockerFixture,
|
|
||||||
test_user_id: str,
|
|
||||||
) -> None:
|
|
||||||
"""User with understanding but no prompts gets empty list."""
|
|
||||||
mock_understanding = MagicMock()
|
|
||||||
mock_understanding.suggested_prompts = []
|
|
||||||
_mock_get_business_understanding(mocker, return_value=mock_understanding)
|
|
||||||
|
|
||||||
response = client.get("/suggested-prompts")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json() == {"prompts": []}
|
|
||||||
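The tests above pin down two behaviors of the `file_ids` handling: non-UUID strings (including path-traversal-looking values and empty strings) must be filtered out before they reach the database query, and the list is capped at 20 entries. A minimal standalone sketch of such a filter, assuming a hypothetical helper name (the actual route code may validate differently, e.g. via Pydantic):

```python
import uuid


def filter_valid_file_ids(file_ids: list[str], limit: int = 20) -> list[str]:
    """Keep only well-formed UUID strings, preserving order, capped at `limit`.

    Hypothetical helper illustrating the behavior the tests assert; not the
    actual implementation in backend.api.features.chat.routes.
    """
    valid: list[str] = []
    for fid in file_ids:
        try:
            # uuid.UUID raises ValueError for malformed input such as
            # "not-a-uuid", "../../../etc/passwd", or "".
            uuid.UUID(fid)
        except ValueError:
            continue
        valid.append(fid)
    return valid[:limit]
```

Only syntactically valid UUIDs survive, so the downstream `find_many(where={"id": {"in": ...}})` query never sees attacker-controlled junk.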
@@ -638,7 +638,7 @@ async def test_process_review_action_auto_approve_creates_auto_approval_records(

     # Mock get_node_executions to return node_id mapping
     mock_get_node_executions = mocker.patch(
-        "backend.api.features.executions.review.routes.get_node_executions"
+        "backend.data.execution.get_node_executions"
     )
     mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
     mock_node_exec.node_exec_id = "test_node_123"
@@ -936,7 +936,7 @@ async def test_process_review_action_auto_approve_only_applies_to_approved_revie

     # Mock get_node_executions to return node_id mapping
     mock_get_node_executions = mocker.patch(
-        "backend.api.features.executions.review.routes.get_node_executions"
+        "backend.data.execution.get_node_executions"
     )
     mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
     mock_node_exec.node_exec_id = "node_exec_approved"
@@ -1148,7 +1148,7 @@ async def test_process_review_action_per_review_auto_approve_granularity(

     # Mock get_node_executions to return batch node data
     mock_get_node_executions = mocker.patch(
-        "backend.api.features.executions.review.routes.get_node_executions"
+        "backend.data.execution.get_node_executions"
    )
     # Create mock node executions for each review
     mock_node_execs = []
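The three hunks above retarget `mocker.patch` from the routes module to `backend.data.execution`, matching where the refactored code now resolves `get_node_executions` at call time. The general rule: `patch` must target the namespace in which the name is looked up when the code runs, not necessarily the module that defines it. A self-contained sketch using stand-in modules built at runtime (the names `demo_lib` and `demo_routes` are hypothetical, not AutoGPT modules):

```python
import sys
import types
from unittest import mock

# "Library" module that defines the original function.
lib = types.ModuleType("demo_lib")
lib.get_node_executions = lambda: "real"
sys.modules["demo_lib"] = lib

# "Routes" module that imported the name with `from demo_lib import ...`,
# i.e. it holds its own binding to the function object.
routes = types.ModuleType("demo_routes")
routes.get_node_executions = lib.get_node_executions
routes.handler = lambda: routes.get_node_executions()
sys.modules["demo_routes"] = routes

# Patching the defining module does NOT affect the copy bound in routes:
with mock.patch("demo_lib.get_node_executions", return_value="fake"):
    assert routes.handler() == "real"

# Patching the namespace where the name is looked up does:
with mock.patch("demo_routes.get_node_executions", return_value="fake"):
    assert routes.handler() == "fake"
```

Once the production code switched to resolving `get_node_executions` from `backend.data.execution` at call time, the tests had to patch that module instead of the routes module.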
@@ -6,15 +6,10 @@ import autogpt_libs.auth as autogpt_auth_lib
 from fastapi import APIRouter, HTTPException, Query, Security, status
 from prisma.enums import ReviewStatus

-from backend.copilot.constants import (
-    is_copilot_synthetic_id,
-    parse_node_id_from_exec_id,
-)
 from backend.data.execution import (
     ExecutionContext,
     ExecutionStatus,
     get_graph_execution_meta,
-    get_node_executions,
 )
 from backend.data.graph import get_graph_settings
 from backend.data.human_review import (
@@ -27,7 +22,6 @@ from backend.data.human_review import (
 )
 from backend.data.model import USER_TIMEZONE_NOT_SET
 from backend.data.user import get_user_by_id
-from backend.data.workspace import get_or_create_workspace
 from backend.executor.utils import add_graph_execution

 from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
@@ -41,38 +35,6 @@ router = APIRouter(
 )


-async def _resolve_node_ids(
-    node_exec_ids: list[str],
-    graph_exec_id: str,
-    is_copilot: bool,
-) -> dict[str, str]:
-    """Resolve node_exec_id -> node_id for auto-approval records.
-
-    CoPilot synthetic IDs encode node_id in the format "{node_id}:{random}".
-    Graph executions look up node_id from NodeExecution records.
-    """
-    if not node_exec_ids:
-        return {}
-
-    if is_copilot:
-        return {neid: parse_node_id_from_exec_id(neid) for neid in node_exec_ids}
-
-    node_execs = await get_node_executions(
-        graph_exec_id=graph_exec_id, include_exec_data=False
-    )
-    node_exec_map = {ne.node_exec_id: ne.node_id for ne in node_execs}
-
-    result = {}
-    for neid in node_exec_ids:
-        if neid in node_exec_map:
-            result[neid] = node_exec_map[neid]
-        else:
-            logger.error(
-                f"Failed to resolve node_id for {neid}: Node execution not found."
-            )
-    return result
-
-
 @router.get(
     "/pending",
     summary="Get Pending Reviews",
@@ -147,16 +109,14 @@ async def list_pending_reviews_for_execution(
     """

     # Verify user owns the graph execution before returning reviews
-    # (CoPilot synthetic IDs don't have graph execution records)
-    if not is_copilot_synthetic_id(graph_exec_id):
-        graph_exec = await get_graph_execution_meta(
-            user_id=user_id, execution_id=graph_exec_id
-        )
-        if not graph_exec:
-            raise HTTPException(
-                status_code=status.HTTP_404_NOT_FOUND,
-                detail=f"Graph execution #{graph_exec_id} not found",
-            )
+    graph_exec = await get_graph_execution_meta(
+        user_id=user_id, execution_id=graph_exec_id
+    )
+    if not graph_exec:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"Graph execution #{graph_exec_id} not found",
+        )

     return await get_pending_reviews_for_execution(graph_exec_id, user_id)

@@ -199,26 +159,30 @@ async def process_review_action(
     )

     graph_exec_id = next(iter(graph_exec_ids))
-    is_copilot = is_copilot_synthetic_id(graph_exec_id)

-    # Validate execution status for graph executions (skip for CoPilot synthetic IDs)
-    if not is_copilot:
-        graph_exec_meta = await get_graph_execution_meta(
-            user_id=user_id, execution_id=graph_exec_id
-        )
-        if not graph_exec_meta:
-            raise HTTPException(
-                status_code=status.HTTP_404_NOT_FOUND,
-                detail=f"Graph execution #{graph_exec_id} not found",
-            )
-        if graph_exec_meta.status not in (
-            ExecutionStatus.REVIEW,
-            ExecutionStatus.INCOMPLETE,
-        ):
-            raise HTTPException(
-                status_code=status.HTTP_409_CONFLICT,
-                detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}",
-            )
+    # Validate execution status before processing reviews
+    graph_exec_meta = await get_graph_execution_meta(
+        user_id=user_id, execution_id=graph_exec_id
+    )
+    if not graph_exec_meta:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"Graph execution #{graph_exec_id} not found",
+        )
+    # Only allow processing reviews if execution is paused for review
+    # or incomplete (partial execution with some reviews already processed)
+    if graph_exec_meta.status not in (
+        ExecutionStatus.REVIEW,
+        ExecutionStatus.INCOMPLETE,
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_409_CONFLICT,
+            detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
+            f"Reviews can only be processed when execution is paused (REVIEW status). "
+            f"Current status: {graph_exec_meta.status}",
+        )

     # Build review decisions map and track which reviews requested auto-approval
     # Auto-approved reviews use original data (no modifications allowed)
@@ -271,7 +235,7 @@ async def process_review_action(
             )
             return (node_id, False)

-    # Collect node_exec_ids that need auto-approval and resolve their node_ids
+    # Collect node_exec_ids that need auto-approval
     node_exec_ids_needing_auto_approval = [
         node_exec_id
         for node_exec_id, review_result in updated_reviews.items()
@@ -279,16 +243,29 @@ async def process_review_action(
         and auto_approve_requests.get(node_exec_id, False)
     ]

-    node_id_map = await _resolve_node_ids(
-        node_exec_ids_needing_auto_approval, graph_exec_id, is_copilot
-    )
-
-    # Deduplicate by node_id — one auto-approval per node
+    # Batch-fetch node executions to get node_ids
     nodes_needing_auto_approval: dict[str, Any] = {}
-    for node_exec_id in node_exec_ids_needing_auto_approval:
-        node_id = node_id_map.get(node_exec_id)
-        if node_id and node_id not in nodes_needing_auto_approval:
-            nodes_needing_auto_approval[node_id] = updated_reviews[node_exec_id]
+    if node_exec_ids_needing_auto_approval:
+        from backend.data.execution import get_node_executions
+
+        node_execs = await get_node_executions(
+            graph_exec_id=graph_exec_id, include_exec_data=False
+        )
+        node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
+
+        for node_exec_id in node_exec_ids_needing_auto_approval:
+            node_exec = node_exec_map.get(node_exec_id)
+            if node_exec:
+                review_result = updated_reviews[node_exec_id]
+                # Use the first approved review for this node (deduplicate by node_id)
+                if node_exec.node_id not in nodes_needing_auto_approval:
+                    nodes_needing_auto_approval[node_exec.node_id] = review_result
+            else:
+                logger.error(
+                    f"Failed to create auto-approval record for {node_exec_id}: "
+                    f"Node execution not found. This may indicate a race condition "
+                    f"or data inconsistency."
+                )

     # Execute all auto-approval creations in parallel (deduplicated by node_id)
     auto_approval_results = await asyncio.gather(
@@ -303,11 +280,13 @@ async def process_review_action(
     auto_approval_failed_count = 0
     for result in auto_approval_results:
         if isinstance(result, Exception):
+            # Unexpected exception during auto-approval creation
             auto_approval_failed_count += 1
             logger.error(
                 f"Unexpected exception during auto-approval creation: {result}"
             )
         elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
+            # Auto-approval creation failed (returned False)
             auto_approval_failed_count += 1

     # Count results
@@ -322,31 +301,30 @@ async def process_review_action(
         if review.status == ReviewStatus.REJECTED
     )

-    # Resume graph execution only for real graph executions (not CoPilot)
-    # CoPilot sessions are resumed by the LLM retrying run_block with review_id
-    if not is_copilot and updated_reviews:
+    # Resume execution only if ALL pending reviews for this execution have been processed
+    if updated_reviews:
         still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)

         if not still_has_pending:
+            # Get the graph_id from any processed review
             first_review = next(iter(updated_reviews.values()))

             try:
+                # Fetch user and settings to build complete execution context
                 user = await get_user_by_id(user_id)
                 settings = await get_graph_settings(
                     user_id=user_id, graph_id=first_review.graph_id
                 )

+                # Preserve user's timezone preference when resuming execution
                 user_timezone = (
                     user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
                 )

-                workspace = await get_or_create_workspace(user_id)
-
                 execution_context = ExecutionContext(
                     human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
                     sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
                     user_timezone=user_timezone,
-                    workspace_id=workspace.id,
                 )

                 await add_graph_execution(
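The replacement block above batch-fetches node executions once, maps `node_exec_id` to its execution record, keeps only the first review per `node_id`, and logs IDs it cannot resolve. The dedup step in isolation might look like this sketch (simplified to plain dicts; the function name and types are hypothetical, not the route code):

```python
import logging

logger = logging.getLogger(__name__)


def dedup_reviews_by_node(
    node_exec_ids: list[str],
    node_exec_to_node: dict[str, str],
    reviews: dict[str, str],
) -> dict[str, str]:
    """Map node_id -> first review seen for that node; skip and log unresolvable IDs."""
    per_node: dict[str, str] = {}
    for neid in node_exec_ids:
        node_id = node_exec_to_node.get(neid)
        if node_id is None:
            # Mirrors the route's error path: the execution record was not found
            logger.error("Node execution not found for %s", neid)
            continue
        # setdefault keeps the FIRST review per node (one auto-approval per node)
        per_node.setdefault(node_id, reviews[neid])
    return per_node
```

Using `setdefault` makes the "first review wins" dedup explicit and keeps the loop a single pass over the requested IDs.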
File diff suppressed because it is too large
@@ -4,6 +4,7 @@ import prisma.enums
 import prisma.models
 import pytest

+import backend.api.features.store.exceptions
 from backend.data.db import connect
 from backend.data.includes import library_agent_include

@@ -143,7 +144,6 @@ async def test_add_agent_to_library(mocker):
     )

     mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
-    mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
     mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
     mock_library_agent.return_value.create = mocker.AsyncMock(
         return_value=mock_library_agent_data
@@ -178,6 +178,7 @@ async def test_add_agent_to_library(mocker):
                 "agentGraphVersion": 1,
             }
         },
+        include={"AgentGraph": True},
     )
     # Check that create was called with the expected data including settings
     create_call_args = mock_library_agent.return_value.create.call_args
@@ -217,7 +218,7 @@ async def test_add_agent_to_library_not_found(mocker):
     )

     # Call function and verify exception
-    with pytest.raises(db.NotFoundError):
+    with pytest.raises(backend.api.features.store.exceptions.AgentNotFoundError):
         await db.add_store_agent_to_library("version123", "test-user")

     # Verify mock called correctly
@@ -1,10 +0,0 @@
-class FolderValidationError(Exception):
-    """Raised when folder operations fail validation."""
-
-    pass
-
-
-class FolderAlreadyExistsError(FolderValidationError):
-    """Raised when a folder with the same name already exists in the location."""
-
-    pass
@@ -26,95 +26,6 @@ class LibraryAgentStatus(str, Enum):
     ERROR = "ERROR"


-# === Folder Models ===
-
-
-class LibraryFolder(pydantic.BaseModel):
-    """Represents a folder for organizing library agents."""
-
-    id: str
-    user_id: str
-    name: str
-    icon: str | None = None
-    color: str | None = None
-    parent_id: str | None = None
-    created_at: datetime.datetime
-    updated_at: datetime.datetime
-    agent_count: int = 0  # Direct agents in folder
-    subfolder_count: int = 0  # Direct child folders
-
-    @staticmethod
-    def from_db(
-        folder: prisma.models.LibraryFolder,
-        agent_count: int = 0,
-        subfolder_count: int = 0,
-    ) -> "LibraryFolder":
-        """Factory method that constructs a LibraryFolder from a Prisma model."""
-        return LibraryFolder(
-            id=folder.id,
-            user_id=folder.userId,
-            name=folder.name,
-            icon=folder.icon,
-            color=folder.color,
-            parent_id=folder.parentId,
-            created_at=folder.createdAt,
-            updated_at=folder.updatedAt,
-            agent_count=agent_count,
-            subfolder_count=subfolder_count,
-        )
-
-
-class LibraryFolderTree(LibraryFolder):
-    """Folder with nested children for tree view."""
-
-    children: list["LibraryFolderTree"] = []
-
-
-class FolderCreateRequest(pydantic.BaseModel):
-    """Request model for creating a folder."""
-
-    name: str = pydantic.Field(..., min_length=1, max_length=100)
-    icon: str | None = None
-    color: str | None = pydantic.Field(
-        None, pattern=r"^#[0-9A-Fa-f]{6}$", description="Hex color code (#RRGGBB)"
-    )
-    parent_id: str | None = None
-
-
-class FolderUpdateRequest(pydantic.BaseModel):
-    """Request model for updating a folder."""
-
-    name: str | None = pydantic.Field(None, min_length=1, max_length=100)
-    icon: str | None = None
-    color: str | None = None
-
-
-class FolderMoveRequest(pydantic.BaseModel):
-    """Request model for moving a folder to a new parent."""
-
-    target_parent_id: str | None = None  # None = move to root
-
-
-class BulkMoveAgentsRequest(pydantic.BaseModel):
-    """Request model for moving multiple agents to a folder."""
-
-    agent_ids: list[str]
-    folder_id: str | None = None  # None = move to root
-
-
-class FolderListResponse(pydantic.BaseModel):
-    """Response schema for a list of folders."""
-
-    folders: list[LibraryFolder]
-    pagination: Pagination
-
-
-class FolderTreeResponse(pydantic.BaseModel):
-    """Response schema for folder tree structure."""
-
-    tree: list[LibraryFolderTree]
-
-
 class MarketplaceListingCreator(pydantic.BaseModel):
     """Creator information for a marketplace listing."""

@@ -165,6 +76,7 @@ class LibraryAgent(pydantic.BaseModel):
     id: str
     graph_id: str
     graph_version: int
+    owner_user_id: str

     image_url: str | None

@@ -205,14 +117,9 @@ class LibraryAgent(pydantic.BaseModel):
         default_factory=list,
         description="List of recent executions with status, score, and summary",
     )
-    can_access_graph: bool = pydantic.Field(
-        description="Indicates whether the same user owns the corresponding graph"
-    )
+    can_access_graph: bool
     is_latest_version: bool
     is_favorite: bool
-    folder_id: str | None = None
-    folder_name: str | None = None  # Denormalized for display

     recommended_schedule_cron: str | None = None
     settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
     marketplace_listing: Optional["MarketplaceListing"] = None
@@ -325,6 +232,7 @@ class LibraryAgent(pydantic.BaseModel):
             id=agent.id,
             graph_id=agent.agentGraphId,
             graph_version=agent.agentGraphVersion,
+            owner_user_id=agent.userId,
             image_url=agent.imageUrl,
             creator_name=creator_name,
             creator_image_url=creator_image_url,
@@ -351,8 +259,6 @@ class LibraryAgent(pydantic.BaseModel):
             can_access_graph=can_access_graph,
             is_latest_version=is_latest_version,
             is_favorite=agent.isFavorite,
-            folder_id=agent.folderId,
-            folder_name=agent.Folder.name if agent.Folder else None,
             recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
             settings=_parse_settings(agent.settings),
             marketplace_listing=marketplace_listing_data,
@@ -564,7 +470,3 @@ class LibraryAgentUpdateRequest(pydantic.BaseModel):
     settings: Optional[GraphSettings] = pydantic.Field(
         default=None, description="User-specific settings for this library agent"
     )
-    folder_id: Optional[str] = pydantic.Field(
-        default=None,
-        description="Folder ID to move agent to (None to move to root)",
-    )
@@ -1,11 +1,9 @@
 import fastapi

 from .agents import router as agents_router
-from .folders import router as folders_router
 from .presets import router as presets_router

 router = fastapi.APIRouter()

 router.include_router(presets_router)
-router.include_router(folders_router)
 router.include_router(agents_router)
@@ -41,14 +41,6 @@ async def list_library_agents(
         ge=1,
         description="Number of agents per page (must be >= 1)",
     ),
-    folder_id: Optional[str] = Query(
-        None,
-        description="Filter by folder ID",
-    ),
-    include_root_only: bool = Query(
-        False,
-        description="Only return agents without a folder (root-level agents)",
-    ),
 ) -> library_model.LibraryAgentResponse:
     """
     Get all agents in the user's library (both created and saved).
@@ -59,8 +51,6 @@ async def list_library_agents(
         sort_by=sort_by,
         page=page,
         page_size=page_size,
-        folder_id=folder_id,
-        include_root_only=include_root_only,
     )


@@ -178,7 +168,6 @@ async def update_library_agent(
         is_favorite=payload.is_favorite,
         is_archived=payload.is_archived,
         settings=payload.settings,
-        folder_id=payload.folder_id,
     )


@@ -1,287 +0,0 @@
-from typing import Optional
-
-import autogpt_libs.auth as autogpt_auth_lib
-from fastapi import APIRouter, Query, Security, status
-from fastapi.responses import Response
-
-from .. import db as library_db
-from .. import model as library_model
-
-router = APIRouter(
-    prefix="/folders",
-    tags=["library", "folders", "private"],
-    dependencies=[Security(autogpt_auth_lib.requires_user)],
-)
-
-
-@router.get(
-    "",
-    summary="List Library Folders",
-    response_model=library_model.FolderListResponse,
-    responses={
-        200: {"description": "List of folders"},
-        500: {"description": "Server error"},
-    },
-)
-async def list_folders(
-    user_id: str = Security(autogpt_auth_lib.get_user_id),
-    parent_id: Optional[str] = Query(
-        None,
-        description="Filter by parent folder ID. If not provided, returns root-level folders.",
-    ),
-    include_relations: bool = Query(
-        True,
-        description="Include agent and subfolder relations (for counts)",
-    ),
-) -> library_model.FolderListResponse:
-    """
-    List folders for the authenticated user.
-
-    Args:
-        user_id: ID of the authenticated user.
-        parent_id: Optional parent folder ID to filter by.
-        include_relations: Whether to include agent and subfolder relations for counts.
-
-    Returns:
-        A FolderListResponse containing folders.
-    """
-    folders = await library_db.list_folders(
-        user_id=user_id,
-        parent_id=parent_id,
-        include_relations=include_relations,
-    )
-    return library_model.FolderListResponse(
-        folders=folders,
-        pagination=library_model.Pagination(
-            total_items=len(folders),
-            total_pages=1,
-            current_page=1,
-            page_size=len(folders),
-        ),
-    )
-
-
-@router.get(
-    "/tree",
-    summary="Get Folder Tree",
-    response_model=library_model.FolderTreeResponse,
-    responses={
-        200: {"description": "Folder tree structure"},
-        500: {"description": "Server error"},
-    },
-)
-async def get_folder_tree(
-    user_id: str = Security(autogpt_auth_lib.get_user_id),
-) -> library_model.FolderTreeResponse:
-    """
-    Get the full folder tree for the authenticated user.
-
-    Args:
-        user_id: ID of the authenticated user.
-
-    Returns:
-        A FolderTreeResponse containing the nested folder structure.
-    """
-    tree = await library_db.get_folder_tree(user_id=user_id)
-    return library_model.FolderTreeResponse(tree=tree)
-
-
-@router.get(
-    "/{folder_id}",
-    summary="Get Folder",
-    response_model=library_model.LibraryFolder,
-    responses={
-        200: {"description": "Folder details"},
-        404: {"description": "Folder not found"},
-        500: {"description": "Server error"},
-    },
-)
-async def get_folder(
-    folder_id: str,
-    user_id: str = Security(autogpt_auth_lib.get_user_id),
-) -> library_model.LibraryFolder:
-    """
-    Get a specific folder.
-
-    Args:
-        folder_id: ID of the folder to retrieve.
-        user_id: ID of the authenticated user.
-
-    Returns:
-        The requested LibraryFolder.
-    """
-    return await library_db.get_folder(folder_id=folder_id, user_id=user_id)
-
-
-@router.post(
-    "",
-    summary="Create Folder",
-    status_code=status.HTTP_201_CREATED,
-    response_model=library_model.LibraryFolder,
-    responses={
-        201: {"description": "Folder created successfully"},
-        400: {"description": "Validation error"},
-        404: {"description": "Parent folder not found"},
-        409: {"description": "Folder name conflict"},
-        500: {"description": "Server error"},
-    },
-)
-async def create_folder(
-    payload: library_model.FolderCreateRequest,
-    user_id: str = Security(autogpt_auth_lib.get_user_id),
-) -> library_model.LibraryFolder:
-    """
-    Create a new folder.
-
-    Args:
-        payload: The folder creation request.
-        user_id: ID of the authenticated user.
-
-    Returns:
-        The created LibraryFolder.
-    """
-    return await library_db.create_folder(
-        user_id=user_id,
-        name=payload.name,
-        parent_id=payload.parent_id,
-        icon=payload.icon,
-        color=payload.color,
-    )
-
-
-@router.patch(
-    "/{folder_id}",
-    summary="Update Folder",
-    response_model=library_model.LibraryFolder,
-    responses={
-        200: {"description": "Folder updated successfully"},
-        400: {"description": "Validation error"},
-        404: {"description": "Folder not found"},
-        409: {"description": "Folder name conflict"},
-        500: {"description": "Server error"},
-    },
-)
-async def update_folder(
-    folder_id: str,
-    payload: library_model.FolderUpdateRequest,
-    user_id: str = Security(autogpt_auth_lib.get_user_id),
-) -> library_model.LibraryFolder:
-    """
-    Update a folder's properties.
-
-    Args:
-        folder_id: ID of the folder to update.
-        payload: The folder update request.
-        user_id: ID of the authenticated user.
-
-    Returns:
-        The updated LibraryFolder.
-    """
-    return await library_db.update_folder(
-        folder_id=folder_id,
-        user_id=user_id,
-        name=payload.name,
-        icon=payload.icon,
-        color=payload.color,
-    )
-
-
-@router.post(
-    "/{folder_id}/move",
-    summary="Move Folder",
-    response_model=library_model.LibraryFolder,
-    responses={
-        200: {"description": "Folder moved successfully"},
-        400: {"description": "Validation error (circular reference)"},
-        404: {"description": "Folder or target parent not found"},
-        409: {"description": "Folder name conflict in target location"},
-        500: {"description": "Server error"},
-    },
-)
-async def move_folder(
-    folder_id: str,
-    payload: library_model.FolderMoveRequest,
-    user_id: str = Security(autogpt_auth_lib.get_user_id),
-) -> library_model.LibraryFolder:
-    """
-    Move a folder to a new parent.
-
-    Args:
-        folder_id: ID of the folder to move.
-        payload: The move request with target parent.
-        user_id: ID of the authenticated user.
-
-    Returns:
-        The moved LibraryFolder.
-    """
-    return await library_db.move_folder(
-        folder_id=folder_id,
-        user_id=user_id,
-        target_parent_id=payload.target_parent_id,
-    )
-
-
-@router.delete(
-    "/{folder_id}",
-    summary="Delete Folder",
-    status_code=status.HTTP_204_NO_CONTENT,
-    responses={
-        204: {"description": "Folder deleted successfully"},
-        404: {"description": "Folder not found"},
-        500: {"description": "Server error"},
-    },
-)
-async def delete_folder(
-    folder_id: str,
-    user_id: str = Security(autogpt_auth_lib.get_user_id),
-) -> Response:
-    """
-    Soft-delete a folder and all its contents.
-
-    Args:
-        folder_id: ID of the folder to delete.
-        user_id: ID of the authenticated user.
-
-    Returns:
-        204 No Content if successful.
-    """
-    await library_db.delete_folder(
-        folder_id=folder_id,
-        user_id=user_id,
-        soft_delete=True,
-    )
-    return Response(status_code=status.HTTP_204_NO_CONTENT)
-
-
-# === Bulk Agent Operations ===
-
-
-@router.post(
-    "/agents/bulk-move",
-    summary="Bulk Move Agents",
-    response_model=list[library_model.LibraryAgent],
-    responses={
-        200: {"description": "Agents moved successfully"},
-        404: {"description": "Folder not found"},
-        500: {"description": "Server error"},
-    },
-)
-async def bulk_move_agents(
-    payload: library_model.BulkMoveAgentsRequest,
-    user_id: str = Security(autogpt_auth_lib.get_user_id),
-) -> list[library_model.LibraryAgent]:
-    """
-    Move multiple agents to a folder.
-
-    Args:
-        payload: The bulk move request with agent IDs and target folder.
-        user_id: ID of the authenticated user.
-
-    Returns:
-        The updated LibraryAgents.
-    """
-    return await library_db.bulk_move_agents_to_folder(
-        agent_ids=payload.agent_ids,
-        folder_id=payload.folder_id,
-        user_id=user_id,
-    )
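The hunk above deletes the library folders router wholesale. One detail worth noting from the removed `list_folders` handler is that it returned all folders as a single page. A minimal standalone sketch of that single-page pagination shape (field names taken from the diff; the real `library_model.Pagination` class may differ):

```python
from dataclasses import dataclass


@dataclass
class Pagination:
    # Mirrors the fields the deleted endpoint populated (assumed shape).
    total_items: int
    total_pages: int
    current_page: int
    page_size: int


def single_page_pagination(items: list) -> Pagination:
    """Wrap an unpaginated result list the way the removed list_folders did."""
    return Pagination(
        total_items=len(items),
        total_pages=1,
        current_page=1,
        page_size=len(items),
    )
```

Note that `page_size` equals the result count, so clients paging past page 1 would always see an empty set.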
@@ -42,6 +42,7 @@ async def test_get_library_agents_success(
         id="test-agent-1",
         graph_id="test-agent-1",
         graph_version=1,
+        owner_user_id=test_user_id,
         name="Test Agent 1",
         description="Test Description 1",
         image_url=None,
@@ -66,6 +67,7 @@ async def test_get_library_agents_success(
         id="test-agent-2",
         graph_id="test-agent-2",
         graph_version=1,
+        owner_user_id=test_user_id,
         name="Test Agent 2",
         description="Test Description 2",
         image_url=None,
@@ -113,8 +115,6 @@ async def test_get_library_agents_success(
         sort_by=library_model.LibraryAgentSort.UPDATED_AT,
         page=1,
         page_size=15,
-        folder_id=None,
-        include_root_only=False,
     )


@@ -129,6 +129,7 @@ async def test_get_favorite_library_agents_success(
         id="test-agent-1",
         graph_id="test-agent-1",
         graph_version=1,
+        owner_user_id=test_user_id,
         name="Favorite Agent 1",
         description="Test Favorite Description 1",
         image_url=None,
@@ -181,6 +182,7 @@ def test_add_agent_to_library_success(
         id="test-library-agent-id",
         graph_id="test-agent-1",
         graph_version=1,
+        owner_user_id=test_user_id,
         name="Test Agent 1",
         description="Test Description 1",
         image_url=None,
@@ -7,24 +7,20 @@ frontend can list available tools on an MCP server before placing a block.

 import logging
 from typing import Annotated, Any
+from urllib.parse import urlparse

 import fastapi
 from autogpt_libs.auth import get_user_id
 from fastapi import Security
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import BaseModel, Field

 from backend.api.features.integrations.router import CredentialsMetaResponse
 from backend.blocks.mcp.client import MCPClient, MCPClientError
-from backend.blocks.mcp.helpers import (
-    auto_lookup_mcp_credential,
-    normalize_mcp_url,
-    server_host,
-)
 from backend.blocks.mcp.oauth import MCPOAuthHandler
 from backend.data.model import OAuth2Credentials
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.integrations.providers import ProviderName
-from backend.util.request import HTTPClientError, Requests, validate_url_host
+from backend.util.request import HTTPClientError, Requests
 from backend.util.settings import Settings

 logger = logging.getLogger(__name__)
@@ -78,20 +74,32 @@ async def discover_tools(
     If the user has a stored MCP credential for this server URL, it will be
     used automatically — no need to pass an explicit auth token.
     """
-    # Validate URL to prevent SSRF — blocks loopback and private IP ranges.
-    try:
-        await validate_url_host(request.server_url)
-    except ValueError as e:
-        raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")

     auth_token = request.auth_token

     # Auto-use stored MCP credential when no explicit token is provided.
     if not auth_token:
-        best_cred = await auto_lookup_mcp_credential(
-            user_id, normalize_mcp_url(request.server_url)
-        )
+        mcp_creds = await creds_manager.store.get_creds_by_provider(
+            user_id, ProviderName.MCP.value
+        )
+        # Find the freshest credential for this server URL
+        best_cred: OAuth2Credentials | None = None
+        for cred in mcp_creds:
+            if (
+                isinstance(cred, OAuth2Credentials)
+                and (cred.metadata or {}).get("mcp_server_url") == request.server_url
+            ):
+                if best_cred is None or (
+                    (cred.access_token_expires_at or 0)
+                    > (best_cred.access_token_expires_at or 0)
+                ):
+                    best_cred = cred
         if best_cred:
+            # Refresh the token if expired before using it
+            best_cred = await creds_manager.refresh_if_needed(user_id, best_cred)
+            logger.info(
+                f"Using MCP credential {best_cred.id} for {request.server_url}, "
+                f"expires_at={best_cred.access_token_expires_at}"
+            )
             auth_token = best_cred.access_token.get_secret_value()

     client = MCPClient(request.server_url, auth_token=auth_token)
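The hunk above inlines a "pick the freshest credential" loop: among the stored MCP credentials whose metadata matches the server URL, it keeps the one with the latest expiry, treating a missing expiry as 0. The selection rule can be sketched standalone (simplified stand-in objects, not the real `OAuth2Credentials` model):

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Cred:
    # Simplified stand-in for OAuth2Credentials (assumed fields only).
    id: str
    metadata: dict = field(default_factory=dict)
    access_token_expires_at: Optional[int] = None


def pick_freshest(creds: list[Cred], server_url: str) -> Optional[Cred]:
    """Return the credential for server_url with the latest expiry (None counts as 0)."""
    best: Optional[Cred] = None
    for cred in creds:
        if (cred.metadata or {}).get("mcp_server_url") != server_url:
            continue
        if best is None or (cred.access_token_expires_at or 0) > (
            best.access_token_expires_at or 0
        ):
            best = cred
    return best
```

One consequence of the `or 0` fallback: a credential with no expiry (e.g. a long-lived token) loses to any credential that has one.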
@@ -126,7 +134,7 @@ async def discover_tools(
         ],
         server_name=(
             init_result.get("serverInfo", {}).get("name")
-            or server_host(request.server_url)
+            or urlparse(request.server_url).hostname
             or "MCP"
         ),
         protocol_version=init_result.get("protocolVersion"),
|
     3. Performs Dynamic Client Registration (RFC 7591) if available
     4. Returns the authorization URL for the frontend to open in a popup
     """
-    # Validate URL to prevent SSRF — blocks loopback and private IP ranges.
-    try:
-        await validate_url_host(request.server_url)
-    except ValueError as e:
-        raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
-
-    # Normalize the URL so that credentials stored here are matched consistently
-    # by auto_lookup_mcp_credential (which also uses normalized URLs).
-    server_url = normalize_mcp_url(request.server_url)
-    client = MCPClient(server_url)
+    client = MCPClient(request.server_url)

     # Step 1: Discover protected-resource metadata (RFC 9728)
     protected_resource = await client.discover_auth()
@@ -183,16 +182,7 @@ async def mcp_oauth_login(

     if protected_resource and protected_resource.get("authorization_servers"):
         auth_server_url = protected_resource["authorization_servers"][0]
-        resource_url = protected_resource.get("resource", server_url)
-
-        # Validate the auth server URL from metadata to prevent SSRF.
-        try:
-            await validate_url_host(auth_server_url)
-        except ValueError as e:
-            raise fastapi.HTTPException(
-                status_code=400,
-                detail=f"Invalid authorization server URL in metadata: {e}",
-            )
+        resource_url = protected_resource.get("resource", request.server_url)

         # Step 2a: Discover auth-server metadata (RFC 8414)
         metadata = await client.discover_auth_server_metadata(auth_server_url)
@@ -202,7 +192,7 @@ async def mcp_oauth_login(
         # Don't assume a resource_url — omitting it lets the auth server choose
         # the correct audience for the token (RFC 8707 resource is optional).
         resource_url = None
-        metadata = await client.discover_auth_server_metadata(server_url)
+        metadata = await client.discover_auth_server_metadata(request.server_url)

     if (
         not metadata
@@ -232,18 +222,12 @@ async def mcp_oauth_login(
     client_id = ""
     client_secret = ""
     if registration_endpoint:
-        # Validate the registration endpoint to prevent SSRF via metadata.
-        try:
-            await validate_url_host(registration_endpoint)
-        except ValueError:
-            pass  # Skip registration, fall back to default client_id
-        else:
-            reg_result = await _register_mcp_client(
-                registration_endpoint, redirect_uri, server_url
-            )
-            if reg_result:
-                client_id = reg_result.get("client_id", "")
-                client_secret = reg_result.get("client_secret", "")
+        reg_result = await _register_mcp_client(
+            registration_endpoint, redirect_uri, request.server_url
+        )
+        if reg_result:
+            client_id = reg_result.get("client_id", "")
+            client_secret = reg_result.get("client_secret", "")

     if not client_id:
         client_id = "autogpt-platform"
@@ -261,7 +245,7 @@ async def mcp_oauth_login(
             "token_url": token_url,
             "revoke_url": revoke_url,
             "resource_url": resource_url,
-            "server_url": server_url,
+            "server_url": request.server_url,
             "client_id": client_id,
             "client_secret": client_secret,
         },
@@ -358,7 +342,7 @@ async def mcp_oauth_callback(
     credentials.metadata["mcp_token_url"] = meta["token_url"]
     credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")

-    hostname = server_host(meta["server_url"])
+    hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
     credentials.title = f"MCP: {hostname}"

     # Remove old MCP credentials for the same server to prevent stale token buildup.
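The new title line in this hunk falls back to the raw URL whenever `urlparse` cannot extract a hostname. A quick check of that stdlib behavior, wrapped in the same shape the callback uses:

```python
from urllib.parse import urlparse


def mcp_title(server_url: str) -> str:
    # Same fallback as the hunk: hostname if parseable, else the raw value.
    hostname = urlparse(server_url).hostname or server_url
    return f"MCP: {hostname}"
```

`urlparse` returns `hostname=None` for scheme-less strings, so a malformed or bare value still yields a usable (if ugly) title rather than `MCP: None`.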
@@ -373,9 +357,7 @@ async def mcp_oauth_callback(
     ):
         await creds_manager.store.delete_creds_by_id(user_id, old.id)
         logger.info(
-            "Removed old MCP credential %s for %s",
-            old.id,
-            server_host(meta["server_url"]),
+            f"Removed old MCP credential {old.id} for {meta['server_url']}"
         )
     except Exception:
         logger.debug("Could not clean up old MCP credentials", exc_info=True)
@@ -393,93 +375,6 @@ async def mcp_oauth_callback(
     )


-# ======================== Bearer Token ======================== #
-
-
-class MCPStoreTokenRequest(BaseModel):
-    """Request to store a bearer token for an MCP server that doesn't support OAuth."""
-
-    server_url: str = Field(
-        description="MCP server URL the token authenticates against"
-    )
-    token: SecretStr = Field(
-        min_length=1, description="Bearer token / API key for the MCP server"
-    )
-
-
-@router.post(
-    "/token",
-    summary="Store a bearer token for an MCP server",
-)
-async def mcp_store_token(
-    request: MCPStoreTokenRequest,
-    user_id: Annotated[str, Security(get_user_id)],
-) -> CredentialsMetaResponse:
-    """
-    Store a manually provided bearer token as an MCP credential.
-
-    Used by the Copilot MCPSetupCard when the server doesn't support the MCP
-    OAuth discovery flow (returns 400 from /oauth/login). Subsequent
-    ``run_mcp_tool`` calls will automatically pick up the token via
-    ``_auto_lookup_credential``.
-    """
-    token = request.token.get_secret_value().strip()
-    if not token:
-        raise fastapi.HTTPException(status_code=422, detail="Token must not be blank.")
-
-    # Validate URL to prevent SSRF — blocks loopback and private IP ranges.
-    try:
-        await validate_url_host(request.server_url)
-    except ValueError as e:
-        raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
-
-    # Normalize URL so trailing-slash variants match existing credentials.
-    server_url = normalize_mcp_url(request.server_url)
-    hostname = server_host(server_url)
-
-    # Collect IDs of old credentials to clean up after successful create.
-    old_cred_ids: list[str] = []
-    try:
-        old_creds = await creds_manager.store.get_creds_by_provider(
-            user_id, ProviderName.MCP.value
-        )
-        old_cred_ids = [
-            old.id
-            for old in old_creds
-            if isinstance(old, OAuth2Credentials)
-            and normalize_mcp_url((old.metadata or {}).get("mcp_server_url", ""))
-            == server_url
-        ]
-    except Exception:
-        logger.debug("Could not query old MCP token credentials", exc_info=True)
-
-    credentials = OAuth2Credentials(
-        provider=ProviderName.MCP.value,
-        title=f"MCP: {hostname}",
-        access_token=SecretStr(token),
-        scopes=[],
-        metadata={"mcp_server_url": server_url},
-    )
-    await creds_manager.create(user_id, credentials)
-
-    # Only delete old credentials after the new one is safely stored.
-    for old_id in old_cred_ids:
-        try:
-            await creds_manager.store.delete_creds_by_id(user_id, old_id)
-        except Exception:
-            logger.debug("Could not clean up old MCP token credential", exc_info=True)
-
-    return CredentialsMetaResponse(
-        id=credentials.id,
-        provider=credentials.provider,
-        type=credentials.type,
-        title=credentials.title,
-        scopes=credentials.scopes,
-        username=credentials.username,
-        host=hostname,
-    )


 # ======================== Helpers ======================== #

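The removed `mcp_store_token` endpoint above deduplicated stored tokens by comparing normalized server URLs. The body of `normalize_mcp_url` is not shown anywhere in this diff, so the following is an assumed behavior (trailing-slash folding only), sketched purely to illustrate the dedup rule:

```python
def normalize_url(url: str) -> str:
    # Assumption: the real normalize_mcp_url at least treats trailing-slash
    # variants of the same server as equal. Its actual rules are not in this diff.
    return url.rstrip("/")


def is_same_server(stored_url: str, new_url: str) -> bool:
    """Dedup rule from the removed endpoint: equal after normalization."""
    return normalize_url(stored_url) == normalize_url(new_url)
```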
@@ -505,7 +400,5 @@ async def _register_mcp_client(
                 return data
         return None
     except Exception as e:
-        logger.warning(
-            "Dynamic client registration failed for %s: %s", server_host(server_url), e
-        )
+        logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
         return None
@@ -11,11 +11,9 @@ import httpx
 import pytest
 import pytest_asyncio
 from autogpt_libs.auth import get_user_id
-from pydantic import SecretStr

 from backend.api.features.mcp.routes import router
 from backend.blocks.mcp.client import MCPClientError, MCPTool
-from backend.data.model import OAuth2Credentials
 from backend.util.request import HTTPClientError

 app = fastapi.FastAPI()
@@ -30,16 +28,6 @@ async def client():
         yield c


-@pytest.fixture(autouse=True)
-def _bypass_ssrf_validation():
-    """Bypass validate_url_host in all route tests (test URLs don't resolve)."""
-    with patch(
-        "backend.api.features.mcp.routes.validate_url_host",
-        new_callable=AsyncMock,
-    ):
-        yield
-
-
 class TestDiscoverTools:
     @pytest.mark.asyncio(loop_scope="session")
     async def test_discover_tools_success(self, client):
@@ -68,12 +56,9 @@ class TestDiscoverTools:

         with (
             patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
-            patch(
-                "backend.api.features.mcp.routes.auto_lookup_mcp_credential",
-                new_callable=AsyncMock,
-                return_value=None,
-            ),
+            patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
         ):
+            mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
             instance = MockClient.return_value
             instance.initialize = AsyncMock(
                 return_value={
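The updated tests in this hunk patch the module-level `creds_manager` and stub its async store method with `AsyncMock`. That patching pattern, shown in isolation with a stand-in coroutine instead of the real route (names here are illustrative, not the actual module paths):

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


async def lookup(creds_manager) -> list:
    # Stand-in for the route body: awaits the (mocked) async store call.
    return await creds_manager.store.get_creds_by_provider("user-1", "mcp")


# MagicMock auto-creates the .store attribute; the async method is stubbed
# explicitly so awaiting it returns the canned value instead of a MagicMock.
mock_cm = MagicMock()
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
result = asyncio.run(lookup(mock_cm))
```

In the real tests, `patch("backend.api.features.mcp.routes.creds_manager")` swaps the module attribute for such a mock for the duration of the `with` block.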
@@ -122,6 +107,10 @@ class TestDiscoverTools:
     @pytest.mark.asyncio(loop_scope="session")
     async def test_discover_tools_auto_uses_stored_credential(self, client):
         """When no explicit token is given, stored MCP credentials are used."""
+        from pydantic import SecretStr
+
+        from backend.data.model import OAuth2Credentials
+
         stored_cred = OAuth2Credentials(
             provider="mcp",
             title="MCP: example.com",
|

         with (
             patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
-            patch(
-                "backend.api.features.mcp.routes.auto_lookup_mcp_credential",
-                new_callable=AsyncMock,
-                return_value=stored_cred,
-            ),
+            patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
         ):
+            mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
+            mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred)
             instance = MockClient.return_value
             instance.initialize = AsyncMock(
                 return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
@@ -162,12 +149,9 @@ class TestDiscoverTools:
     async def test_discover_tools_mcp_error(self, client):
         with (
             patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
-            patch(
-                "backend.api.features.mcp.routes.auto_lookup_mcp_credential",
-                new_callable=AsyncMock,
-                return_value=None,
-            ),
+            patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
         ):
+            mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
             instance = MockClient.return_value
             instance.initialize = AsyncMock(
                 side_effect=MCPClientError("Connection refused")
@@ -185,12 +169,9 @@ class TestDiscoverTools:
     async def test_discover_tools_generic_error(self, client):
         with (
             patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
-            patch(
-                "backend.api.features.mcp.routes.auto_lookup_mcp_credential",
-                new_callable=AsyncMock,
-                return_value=None,
-            ),
+            patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
         ):
+            mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
             instance = MockClient.return_value
             instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))

@@ -206,12 +187,9 @@ class TestDiscoverTools:
     async def test_discover_tools_auth_required(self, client):
         with (
             patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
-            patch(
-                "backend.api.features.mcp.routes.auto_lookup_mcp_credential",
-                new_callable=AsyncMock,
-                return_value=None,
-            ),
+            patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
         ):
+            mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
             instance = MockClient.return_value
             instance.initialize = AsyncMock(
                 side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
@@ -229,12 +207,9 @@ class TestDiscoverTools:
|
|||||||
async def test_discover_tools_forbidden(self, client):
|
async def test_discover_tools_forbidden(self, client):
|
||||||
with (
|
with (
|
||||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||||
patch(
|
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||||
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
return_value=None,
|
|
||||||
),
|
|
||||||
):
|
):
|
||||||
|
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||||
instance = MockClient.return_value
|
instance = MockClient.return_value
|
||||||
instance.initialize = AsyncMock(
|
instance.initialize = AsyncMock(
|
||||||
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
|
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
|
||||||
@@ -356,6 +331,10 @@ class TestOAuthLogin:
|
|||||||
class TestOAuthCallback:
|
class TestOAuthCallback:
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_oauth_callback_success(self, client):
|
async def test_oauth_callback_success(self, client):
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.data.model import OAuth2Credentials
|
||||||
|
|
||||||
mock_creds = OAuth2Credentials(
|
mock_creds = OAuth2Credentials(
|
||||||
provider="mcp",
|
provider="mcp",
|
||||||
title=None,
|
title=None,
|
||||||
@@ -455,118 +434,3 @@ class TestOAuthCallback:
|
|||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
assert "token exchange failed" in response.json()["detail"].lower()
|
assert "token exchange failed" in response.json()["detail"].lower()
|
||||||
|
|
||||||
|
|
||||||
class TestStoreToken:
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_store_token_success(self, client):
|
|
||||||
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
|
|
||||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
|
||||||
mock_cm.create = AsyncMock()
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/token",
|
|
||||||
json={
|
|
||||||
"server_url": "https://mcp.example.com/mcp",
|
|
||||||
"token": "my-api-key-123",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["provider"] == "mcp"
|
|
||||||
assert data["type"] == "oauth2"
|
|
||||||
assert data["host"] == "mcp.example.com"
|
|
||||||
mock_cm.create.assert_called_once()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_store_token_blank_rejected(self, client):
|
|
||||||
"""Blank token string (after stripping) should return 422."""
|
|
||||||
response = await client.post(
|
|
||||||
"/token",
|
|
||||||
json={
|
|
||||||
"server_url": "https://mcp.example.com/mcp",
|
|
||||||
"token": " ",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# Pydantic min_length=1 catches the whitespace-only token
|
|
||||||
assert response.status_code == 422
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_store_token_replaces_old_credential(self, client):
|
|
||||||
old_cred = OAuth2Credentials(
|
|
||||||
provider="mcp",
|
|
||||||
title="MCP: mcp.example.com",
|
|
||||||
access_token=SecretStr("old-token"),
|
|
||||||
scopes=[],
|
|
||||||
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
|
|
||||||
)
|
|
||||||
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
|
|
||||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[old_cred])
|
|
||||||
mock_cm.create = AsyncMock()
|
|
||||||
mock_cm.store.delete_creds_by_id = AsyncMock()
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/token",
|
|
||||||
json={
|
|
||||||
"server_url": "https://mcp.example.com/mcp",
|
|
||||||
"token": "new-token",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
mock_cm.store.delete_creds_by_id.assert_called_once_with(
|
|
||||||
"test-user-id", old_cred.id
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestSSRFValidation:
|
|
||||||
"""Verify that validate_url_host is enforced on all endpoints."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_tools_ssrf_blocked(self, client):
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.mcp.routes.validate_url_host",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
side_effect=ValueError("blocked loopback"),
|
|
||||||
):
|
|
||||||
response = await client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={"server_url": "http://localhost/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert "blocked loopback" in response.json()["detail"].lower()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_oauth_login_ssrf_blocked(self, client):
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.mcp.routes.validate_url_host",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
side_effect=ValueError("blocked private IP"),
|
|
||||||
):
|
|
||||||
response = await client.post(
|
|
||||||
"/oauth/login",
|
|
||||||
json={"server_url": "http://10.0.0.1/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert "blocked private ip" in response.json()["detail"].lower()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_store_token_ssrf_blocked(self, client):
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.mcp.routes.validate_url_host",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
side_effect=ValueError("blocked loopback"),
|
|
||||||
):
|
|
||||||
response = await client.post(
|
|
||||||
"/token",
|
|
||||||
json={
|
|
||||||
"server_url": "http://127.0.0.1/mcp",
|
|
||||||
"token": "some-token",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert "blocked loopback" in response.json()["detail"].lower()
|
|
||||||
|
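The removed `TestSSRFValidation` tests exercise a `validate_url_host` guard that rejects loopback and private addresses. A minimal standalone sketch of such a check (hypothetical; the real helper lives in the backend, is async, and also resolves DNS names):

```python
import ipaddress
from urllib.parse import urlparse


def validate_url_host(url: str) -> None:
    """Reject URLs whose host is a loopback or private IP (basic SSRF guard)."""
    host = urlparse(url).hostname
    if host is None:
        raise ValueError("URL has no host")
    if host == "localhost":
        raise ValueError("blocked loopback")
    try:
        ip = ipaddress.ip_address(host)
    except ValueError:
        return  # a DNS name; a real guard would resolve it and re-check the IP
    if ip.is_loopback:
        raise ValueError("blocked loopback")
    if ip.is_private:
        raise ValueError("blocked private IP")
```

This mirrors the three blocked inputs in the tests: `localhost`, `10.0.0.1`, and `127.0.0.1`.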
```diff
@@ -1,3 +1,5 @@
+from typing import Literal
+
 from backend.util.cache import cached
 
 from . import db as store_db
@@ -21,7 +23,7 @@ def clear_all_caches():
 async def _get_cached_store_agents(
     featured: bool,
     creator: str | None,
-    sorted_by: store_db.StoreAgentsSortOptions | None,
+    sorted_by: Literal["rating", "runs", "name", "updated_at"] | None,
     search_query: str | None,
     category: str | None,
     page: int,
@@ -55,7 +57,7 @@ async def _get_cached_agent_details(
 async def _get_cached_store_creators(
     featured: bool,
     search_query: str | None,
-    sorted_by: store_db.StoreCreatorsSortOptions | None,
+    sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None,
     page: int,
     page_size: int,
 ):
@@ -73,4 +75,4 @@ async def _get_cached_store_creators(
 @cached(maxsize=100, ttl_seconds=300, shared_cache=True)
 async def _get_cached_creator_details(username: str):
     """Cached helper to get creator details."""
-    return await store_db.get_store_creator(username=username.lower())
+    return await store_db.get_store_creator_details(username=username.lower())
```
```diff
@@ -9,26 +9,15 @@ import logging
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, get_args, get_origin
+from typing import Any
 
 from prisma.enums import ContentType
 
-from backend.blocks.llm import LlmModel
 from backend.data.db import query_raw_with_schema
 
 logger = logging.getLogger(__name__)
 
 
-def _contains_type(annotation: Any, target: type) -> bool:
-    """Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
-    if annotation is target:
-        return True
-    origin = get_origin(annotation)
-    if origin is None:
-        return False
-    return any(_contains_type(arg, target) for arg in get_args(annotation))
-
-
 @dataclass
 class ContentItem:
     """Represents a piece of content to be embedded."""
@@ -199,51 +188,45 @@ class BlockHandler(ContentHandler):
             try:
                 block_instance = block_cls()
 
+                # Skip disabled blocks - they shouldn't be indexed
                 if block_instance.disabled:
                     continue
 
                 # Build searchable text from block metadata
                 parts = []
-                if block_instance.name:
+                if hasattr(block_instance, "name") and block_instance.name:
                     parts.append(block_instance.name)
-                if block_instance.description:
+                if (
+                    hasattr(block_instance, "description")
+                    and block_instance.description
+                ):
                     parts.append(block_instance.description)
-                if block_instance.categories:
+                if hasattr(block_instance, "categories") and block_instance.categories:
+                    # Convert BlockCategory enum to strings
                     parts.append(
                         " ".join(str(cat.value) for cat in block_instance.categories)
                     )
 
-                # Add input schema field descriptions
-                block_input_fields = block_instance.input_schema.model_fields
-                parts += [
-                    f"{field_name}: {field_info.description}"
-                    for field_name, field_info in block_input_fields.items()
-                    if field_info.description
-                ]
+                # Add input/output schema info
+                if hasattr(block_instance, "input_schema"):
+                    schema = block_instance.input_schema
+                    if hasattr(schema, "model_json_schema"):
+                        schema_dict = schema.model_json_schema()
+                        if "properties" in schema_dict:
+                            for prop_name, prop_info in schema_dict[
+                                "properties"
+                            ].items():
+                                if "description" in prop_info:
+                                    parts.append(
+                                        f"{prop_name}: {prop_info['description']}"
+                                    )
 
                 searchable_text = " ".join(parts)
 
+                # Convert categories set of enums to list of strings for JSON serialization
+                categories = getattr(block_instance, "categories", set())
                 categories_list = (
-                    [cat.value for cat in block_instance.categories]
-                    if block_instance.categories
-                    else []
-                )
-
-                # Extract provider names from credentials fields
-                credentials_info = (
-                    block_instance.input_schema.get_credentials_fields_info()
-                )
-                is_integration = len(credentials_info) > 0
-                provider_names = [
-                    provider.value.lower()
-                    for info in credentials_info.values()
-                    for provider in info.provider
-                ]
-
-                # Check if block has LlmModel field in input schema
-                has_llm_model_field = any(
-                    _contains_type(field.annotation, LlmModel)
-                    for field in block_instance.input_schema.model_fields.values()
+                    [cat.value for cat in categories] if categories else []
                 )
 
                 items.append(
@@ -252,11 +235,8 @@ class BlockHandler(ContentHandler):
                         content_type=ContentType.BLOCK,
                         searchable_text=searchable_text,
                         metadata={
-                            "name": block_instance.name,
+                            "name": getattr(block_instance, "name", ""),
                             "categories": categories_list,
-                            "providers": provider_names,
-                            "has_llm_model_field": has_llm_model_field,
-                            "is_integration": is_integration,
                         },
                         user_id=None,  # Blocks are public
                     )
```
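The `BlockHandler` hunks replace direct attribute access with `hasattr`/`getattr` guards so a block missing `description`, `categories`, or `input_schema` still produces a searchable entry. A simplified, self-contained sketch of that extraction logic (the names below mirror the diff but this is not the real handler):

```python
def build_searchable_text(block) -> str:
    """Build index text from a block, tolerating missing attributes."""
    parts = []
    if hasattr(block, "name") and block.name:
        parts.append(block.name)
    if hasattr(block, "description") and block.description:
        parts.append(block.description)
    if hasattr(block, "input_schema"):
        schema = block.input_schema
        if hasattr(schema, "model_json_schema"):
            # Pull per-field descriptions out of the JSON schema's properties
            for prop, info in schema.model_json_schema().get("properties", {}).items():
                if "description" in info:
                    parts.append(f"{prop}: {info['description']}")
    return " ".join(parts)
```

A minimal block with only a `name` falls through every guard and still yields usable text, which is exactly what `test_block_handler_handles_missing_attributes` asserts.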
```diff
@@ -82,10 +82,9 @@ async def test_block_handler_get_missing_items(mocker):
     mock_block_instance.description = "Performs calculations"
     mock_block_instance.categories = [MagicMock(value="MATH")]
     mock_block_instance.disabled = False
-    mock_field = MagicMock()
-    mock_field.description = "Math expression to evaluate"
-    mock_block_instance.input_schema.model_fields = {"expression": mock_field}
-    mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
+    mock_block_instance.input_schema.model_json_schema.return_value = {
+        "properties": {"expression": {"description": "Math expression to evaluate"}}
+    }
     mock_block_class.return_value = mock_block_instance
 
     mock_blocks = {"block-uuid-1": mock_block_class}
@@ -310,19 +309,19 @@ async def test_content_handlers_registry():
 
 
 @pytest.mark.asyncio(loop_scope="session")
-async def test_block_handler_handles_empty_attributes():
-    """Test BlockHandler handles blocks with empty/falsy attribute values."""
+async def test_block_handler_handles_missing_attributes():
+    """Test BlockHandler gracefully handles blocks with missing attributes."""
     handler = BlockHandler()
 
-    # Mock block with empty values (all attributes exist but are falsy)
+    # Mock block with minimal attributes
     mock_block_class = MagicMock()
     mock_block_instance = MagicMock()
     mock_block_instance.name = "Minimal Block"
     mock_block_instance.disabled = False
-    mock_block_instance.description = ""
-    mock_block_instance.categories = set()
-    mock_block_instance.input_schema.model_fields = {}
-    mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
+    # No description, categories, or schema
+    del mock_block_instance.description
+    del mock_block_instance.categories
+    del mock_block_instance.input_schema
     mock_block_class.return_value = mock_block_instance
 
     mock_blocks = {"block-minimal": mock_block_class}
@@ -353,8 +352,6 @@ async def test_block_handler_skips_failed_blocks():
     good_instance.description = "Works fine"
     good_instance.categories = []
     good_instance.disabled = False
-    good_instance.input_schema.model_fields = {}
-    good_instance.input_schema.get_credentials_fields_info.return_value = {}
     good_block.return_value = good_instance
 
     bad_block = MagicMock()
```
(File diff suppressed because it is too large.)
```diff
@@ -26,7 +26,7 @@ async def test_get_store_agents(mocker):
     mock_agents = [
         prisma.models.StoreAgent(
             listing_id="test-id",
-            listing_version_id="version123",
+            storeListingVersionId="version123",
             slug="test-agent",
             agent_name="Test Agent",
             agent_video=None,
@@ -40,11 +40,11 @@ async def test_get_store_agents(mocker):
             runs=10,
             rating=4.5,
             versions=["1.0"],
-            graph_id="test-graph-id",
-            graph_versions=["1"],
+            agentGraphVersions=["1"],
+            agentGraphId="test-graph-id",
             updated_at=datetime.now(),
             is_available=False,
-            use_for_onboarding=False,
+            useForOnboarding=False,
         )
     ]
 
@@ -68,10 +68,10 @@ async def test_get_store_agents(mocker):
 
 @pytest.mark.asyncio(loop_scope="session")
 async def test_get_store_agent_details(mocker):
-    # Mock data - StoreAgent view already contains the active version data
+    # Mock data
     mock_agent = prisma.models.StoreAgent(
         listing_id="test-id",
-        listing_version_id="version123",
+        storeListingVersionId="version123",
         slug="test-agent",
         agent_name="Test Agent",
         agent_video="video.mp4",
@@ -85,38 +85,102 @@ async def test_get_store_agent_details(mocker):
         runs=10,
         rating=4.5,
         versions=["1.0"],
-        graph_id="test-graph-id",
-        graph_versions=["1"],
+        agentGraphVersions=["1"],
+        agentGraphId="test-graph-id",
         updated_at=datetime.now(),
-        is_available=True,
-        use_for_onboarding=False,
+        is_available=False,
+        useForOnboarding=False,
     )
 
-    # Mock StoreAgent prisma call
+    # Mock active version agent (what we want to return for active version)
+    mock_active_agent = prisma.models.StoreAgent(
+        listing_id="test-id",
+        storeListingVersionId="active-version-id",
+        slug="test-agent",
+        agent_name="Test Agent Active",
+        agent_video="active_video.mp4",
+        agent_image=["active_image.jpg"],
+        featured=False,
+        creator_username="creator",
+        creator_avatar="avatar.jpg",
+        sub_heading="Test heading active",
+        description="Test description active",
+        categories=["test"],
+        runs=15,
+        rating=4.8,
+        versions=["1.0", "2.0"],
+        agentGraphVersions=["1", "2"],
+        agentGraphId="test-graph-id-active",
+        updated_at=datetime.now(),
+        is_available=True,
+        useForOnboarding=False,
+    )
+
+    # Create a mock StoreListing result
+    mock_store_listing = mocker.MagicMock()
+    mock_store_listing.activeVersionId = "active-version-id"
+    mock_store_listing.hasApprovedVersion = True
+    mock_store_listing.ActiveVersion = mocker.MagicMock()
+    mock_store_listing.ActiveVersion.recommendedScheduleCron = None
+
+    # Mock StoreAgent prisma call - need to handle multiple calls
     mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
-    mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
+
+    # Set up side_effect to return different results for different calls
+    def mock_find_first_side_effect(*args, **kwargs):
+        where_clause = kwargs.get("where", {})
+        if "storeListingVersionId" in where_clause:
+            # Second call for active version
+            return mock_active_agent
+        else:
+            # First call for initial lookup
+            return mock_agent
+
+    mock_store_agent.return_value.find_first = mocker.AsyncMock(
+        side_effect=mock_find_first_side_effect
+    )
+
+    # Mock Profile prisma call
+    mock_profile = mocker.MagicMock()
+    mock_profile.userId = "user-id-123"
+    mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
+    mock_profile_db.return_value.find_first = mocker.AsyncMock(
+        return_value=mock_profile
+    )
+
+    # Mock StoreListing prisma call
+    mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
+    mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
+        return_value=mock_store_listing
+    )
+
     # Call function
     result = await db.get_store_agent_details("creator", "test-agent")
 
-    # Verify results - constructed from the StoreAgent view
+    # Verify results - should use active version data
     assert result.slug == "test-agent"
-    assert result.agent_name == "Test Agent"
-    assert result.active_version_id == "version123"
+    assert result.agent_name == "Test Agent Active"  # From active version
+    assert result.active_version_id == "active-version-id"
     assert result.has_approved_version is True
-    assert result.store_listing_version_id == "version123"
-    assert result.graph_id == "test-graph-id"
-    assert result.runs == 10
-    assert result.rating == 4.5
+    assert (
+        result.store_listing_version_id == "active-version-id"
+    )  # Should be active version ID
 
-    # Verify single StoreAgent lookup
-    mock_store_agent.return_value.find_first.assert_called_once_with(
+    # Verify mocks called correctly - now expecting 2 calls
+    assert mock_store_agent.return_value.find_first.call_count == 2
+
+    # Check the specific calls
+    calls = mock_store_agent.return_value.find_first.call_args_list
+    assert calls[0] == mocker.call(
         where={"creator_username": "creator", "slug": "test-agent"}
     )
+    assert calls[1] == mocker.call(where={"storeListingVersionId": "active-version-id"})
+
+    mock_store_listing_db.return_value.find_first.assert_called_once()
 
 
 @pytest.mark.asyncio(loop_scope="session")
-async def test_get_store_creator(mocker):
+async def test_get_store_creator_details(mocker):
     # Mock data
     mock_creator_data = prisma.models.Creator(
         name="Test Creator",
@@ -138,7 +202,7 @@ async def test_get_store_creator(mocker):
     mock_creator.return_value.find_unique.return_value = mock_creator_data
 
     # Call function
-    result = await db.get_store_creator("creator")
+    result = await db.get_store_creator_details("creator")
 
     # Verify results
     assert result.username == "creator"
@@ -154,110 +218,61 @@
 
 @pytest.mark.asyncio(loop_scope="session")
 async def test_create_store_submission(mocker):
-    now = datetime.now()
+    # Mock data
 
-    # Mock agent graph (with no pending submissions) and user with profile
-    mock_profile = prisma.models.Profile(
-        id="profile-id",
-        userId="user-id",
-        name="Test User",
-        username="testuser",
-        description="Test",
-        isFeatured=False,
-        links=[],
-        createdAt=now,
-        updatedAt=now,
-    )
-    mock_user = prisma.models.User(
-        id="user-id",
-        email="test@example.com",
-        createdAt=now,
-        updatedAt=now,
-        Profile=[mock_profile],
-        emailVerified=True,
-        metadata="{}",  # type: ignore[reportArgumentType]
-        integrations="",
-        maxEmailsPerDay=1,
-        notifyOnAgentRun=True,
-        notifyOnZeroBalance=True,
-        notifyOnLowBalance=True,
-        notifyOnBlockExecutionFailed=True,
-        notifyOnContinuousAgentError=True,
-        notifyOnDailySummary=True,
-        notifyOnWeeklySummary=True,
-        notifyOnMonthlySummary=True,
-        notifyOnAgentApproved=True,
-        notifyOnAgentRejected=True,
-        timezone="Europe/Delft",
-    )
     mock_agent = prisma.models.AgentGraph(
         id="agent-id",
         version=1,
         userId="user-id",
-        createdAt=now,
+        createdAt=datetime.now(),
         isActive=True,
-        StoreListingVersions=[],
-        User=mock_user,
     )
 
-    # Mock the created StoreListingVersion (returned by create)
-    mock_store_listing_obj = prisma.models.StoreListing(
+    mock_listing = prisma.models.StoreListing(
         id="listing-id",
-        createdAt=now,
-        updatedAt=now,
+        createdAt=datetime.now(),
+        updatedAt=datetime.now(),
         isDeleted=False,
         hasApprovedVersion=False,
         slug="test-agent",
         agentGraphId="agent-id",
         owningUserId="user-id",
-        useForOnboarding=False,
-    )
-    mock_version = prisma.models.StoreListingVersion(
-        id="version-id",
-        agentGraphId="agent-id",
-        agentGraphVersion=1,
-        name="Test Agent",
-        description="Test description",
-        createdAt=now,
-        updatedAt=now,
-        subHeading="",
-        imageUrls=[],
-        categories=[],
-        isFeatured=False,
-        isDeleted=False,
-        version=1,
-        storeListingId="listing-id",
-        submissionStatus=prisma.enums.SubmissionStatus.PENDING,
-        isAvailable=True,
-        submittedAt=now,
-        StoreListing=mock_store_listing_obj,
+        Versions=[
+            prisma.models.StoreListingVersion(
+                id="version-id",
+                agentGraphId="agent-id",
+                agentGraphVersion=1,
+                name="Test Agent",
+                description="Test description",
+                createdAt=datetime.now(),
+                updatedAt=datetime.now(),
+                subHeading="Test heading",
+                imageUrls=["image.jpg"],
+                categories=["test"],
+                isFeatured=False,
+                isDeleted=False,
+                version=1,
+                storeListingId="listing-id",
+                submissionStatus=prisma.enums.SubmissionStatus.PENDING,
+                isAvailable=True,
+            )
+        ],
+        useForOnboarding=False,
     )
 
     # Mock prisma calls
     mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
     mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
 
-    # Mock transaction context manager
-    mock_tx = mocker.MagicMock()
-    mocker.patch(
-        "backend.api.features.store.db.transaction",
-        return_value=mocker.AsyncMock(
-            __aenter__=mocker.AsyncMock(return_value=mock_tx),
-            __aexit__=mocker.AsyncMock(return_value=False),
-        ),
-    )
-
-    mock_sl = mocker.patch("prisma.models.StoreListing.prisma")
-    mock_sl.return_value.find_unique = mocker.AsyncMock(return_value=None)
-
-    mock_slv = mocker.patch("prisma.models.StoreListingVersion.prisma")
-    mock_slv.return_value.create = mocker.AsyncMock(return_value=mock_version)
+    mock_store_listing = mocker.patch("prisma.models.StoreListing.prisma")
+    mock_store_listing.return_value.find_first = mocker.AsyncMock(return_value=None)
+    mock_store_listing.return_value.create = mocker.AsyncMock(return_value=mock_listing)
 
     # Call function
     result = await db.create_store_submission(
         user_id="user-id",
-        graph_id="agent-id",
-        graph_version=1,
+        agent_id="agent-id",
+        agent_version=1,
         slug="test-agent",
         name="Test Agent",
         description="Test description",
@@ -266,11 +281,11 @@ async def test_create_store_submission(mocker):
     # Verify results
     assert result.name == "Test Agent"
     assert result.description == "Test description"
-    assert result.listing_version_id == "version-id"
+    assert result.store_listing_version_id == "version-id"
 
     # Verify mocks called correctly
     mock_agent_graph.return_value.find_first.assert_called_once()
-    mock_slv.return_value.create.assert_called_once()
+    mock_store_listing.return_value.create.assert_called_once()
 
 
 @pytest.mark.asyncio(loop_scope="session")
@@ -303,6 +318,7 @@ async def test_update_profile(mocker):
         description="Test description",
         links=["link1"],
         avatar_url="avatar.jpg",
+        is_featured=False,
     )
 
     # Call function
@@ -373,7 +389,7 @@ async def test_get_store_agents_with_search_and_filters_parameterized():
         creators=["creator1'; DROP TABLE Users; --", "creator2"],
         category="AI'; DELETE FROM StoreAgent; --",
         featured=True,
-        sorted_by=db.StoreAgentsSortOptions.RATING,
+        sorted_by="rating",
         page=1,
         page_size=20,
     )
```
||||||
@@ -57,6 +57,12 @@ class StoreError(ValueError):
     pass


+class AgentNotFoundError(NotFoundError):
+    """Raised when an agent is not found"""
+
+    pass
+
+
 class CreatorNotFoundError(NotFoundError):
     """Raised when a creator is not found"""

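The exception hierarchy added in the hunk above can be sketched in isolation. A minimal, self-contained sketch, where the local `NotFoundError` base is a stand-in for `backend.util.exceptions.NotFoundError` and is assumed to be a `ValueError` subclass:

```python
class NotFoundError(ValueError):
    """Stand-in for backend.util.exceptions.NotFoundError (assumed ValueError subclass)."""


class AgentNotFoundError(NotFoundError):
    """Raised when an agent is not found"""


class CreatorNotFoundError(NotFoundError):
    """Raised when a creator is not found"""


def find_agent(agents: dict[str, str], agent_id: str) -> str:
    # Raising the specific subclass lets callers either catch it precisely
    # or catch NotFoundError generically and map it to a 404 response.
    if agent_id not in agents:
        raise AgentNotFoundError(f"agent {agent_id} not found")
    return agents[agent_id]
```

The point of subclassing one base is that route-level error handlers only need a single `except NotFoundError` branch.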
@@ -568,7 +568,7 @@ async def hybrid_search(
             SELECT uce."contentId" as "storeListingVersionId"
             FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
             INNER JOIN {{schema_prefix}}"StoreAgent" sa
-                ON uce."contentId" = sa.listing_version_id
+                ON uce."contentId" = sa."storeListingVersionId"
             WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
               AND uce."userId" IS NULL
               AND uce.search @@ plainto_tsquery('english', {query_param})
@@ -582,7 +582,7 @@ async def hybrid_search(
             SELECT uce."contentId", uce.embedding
             FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
             INNER JOIN {{schema_prefix}}"StoreAgent" sa
-                ON uce."contentId" = sa.listing_version_id
+                ON uce."contentId" = sa."storeListingVersionId"
             WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
               AND uce."userId" IS NULL
               AND {where_clause}
@@ -605,7 +605,7 @@ async def hybrid_search(
             sa.featured,
             sa.is_available,
             sa.updated_at,
-            sa.graph_id,
+            sa."agentGraphId",
             -- Searchable text for BM25 reranking
             COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
             -- Semantic score
@@ -627,9 +627,9 @@ async def hybrid_search(
             sa.runs as popularity_raw
         FROM candidates c
         INNER JOIN {{schema_prefix}}"StoreAgent" sa
-            ON c."storeListingVersionId" = sa.listing_version_id
+            ON c."storeListingVersionId" = sa."storeListingVersionId"
         INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
-            ON sa.listing_version_id = uce."contentId"
+            ON sa."storeListingVersionId" = uce."contentId"
             AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
     ),
     max_vals AS (
@@ -665,7 +665,7 @@ async def hybrid_search(
         featured,
         is_available,
         updated_at,
-        graph_id,
+        "agentGraphId",
         searchable_text,
         semantic_score,
         lexical_score,
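The query above builds a lexical (tsquery) candidate set and a semantic (embedding) one, then normalizes both signals against `max_vals` before ranking. A rough pure-Python sketch of that scoring idea; the min-max normalization and the 0.7/0.3 weights are illustrative assumptions, not values taken from the SQL:

```python
def combine_scores(
    semantic: list[float],
    lexical: list[float],
    w_semantic: float = 0.7,  # assumed weight, for illustration only
    w_lexical: float = 0.3,   # assumed weight, for illustration only
) -> list[float]:
    """Min-max normalize each signal to [0, 1], then mix with fixed weights."""

    def norm(xs: list[float]) -> list[float]:
        lo, hi = min(xs), max(xs)
        # Degenerate case: all candidates scored equally.
        return [0.0 if hi == lo else (x - lo) / (hi - lo) for x in xs]

    s, lx = norm(semantic), norm(lexical)
    return [w_semantic * a + w_lexical * b for a, b in zip(s, lx)]
```

Normalizing before mixing matters because embedding similarities and ts_rank scores live on very different scales.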
@@ -1,14 +1,11 @@
 import datetime
-from typing import TYPE_CHECKING, List, Self
+from typing import List

 import prisma.enums
 import pydantic

 from backend.util.models import Pagination

-if TYPE_CHECKING:
-    import prisma.models
-

 class ChangelogEntry(pydantic.BaseModel):
     version: str
@@ -16,9 +13,9 @@ class ChangelogEntry(pydantic.BaseModel):
     date: datetime.datetime


-class MyUnpublishedAgent(pydantic.BaseModel):
-    graph_id: str
-    graph_version: int
+class MyAgent(pydantic.BaseModel):
+    agent_id: str
+    agent_version: int
     agent_name: str
     agent_image: str | None = None
     description: str
@@ -26,8 +23,8 @@ class MyUnpublishedAgent(pydantic.BaseModel):
     recommended_schedule_cron: str | None = None


-class MyUnpublishedAgentsResponse(pydantic.BaseModel):
-    agents: list[MyUnpublishedAgent]
+class MyAgentsResponse(pydantic.BaseModel):
+    agents: list[MyAgent]
     pagination: Pagination


@@ -43,21 +40,6 @@ class StoreAgent(pydantic.BaseModel):
     rating: float
     agent_graph_id: str

-    @classmethod
-    def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgent":
-        return cls(
-            slug=agent.slug,
-            agent_name=agent.agent_name,
-            agent_image=agent.agent_image[0] if agent.agent_image else "",
-            creator=agent.creator_username or "Needs Profile",
-            creator_avatar=agent.creator_avatar or "",
-            sub_heading=agent.sub_heading,
-            description=agent.description,
-            runs=agent.runs,
-            rating=agent.rating,
-            agent_graph_id=agent.graph_id,
-        )
-

 class StoreAgentsResponse(pydantic.BaseModel):
     agents: list[StoreAgent]
@@ -80,192 +62,81 @@ class StoreAgentDetails(pydantic.BaseModel):
     runs: int
     rating: float
     versions: list[str]
-    graph_id: str
-    graph_versions: list[str]
+    agentGraphVersions: list[str]
+    agentGraphId: str
     last_updated: datetime.datetime
     recommended_schedule_cron: str | None = None

-    active_version_id: str
-    has_approved_version: bool
+    active_version_id: str | None = None
+    has_approved_version: bool = False

     # Optional changelog data when include_changelog=True
     changelog: list[ChangelogEntry] | None = None

-    @classmethod
-    def from_db(cls, agent: "prisma.models.StoreAgent") -> "StoreAgentDetails":
-        return cls(
-            store_listing_version_id=agent.listing_version_id,
-            slug=agent.slug,
-            agent_name=agent.agent_name,
-            agent_video=agent.agent_video or "",
-            agent_output_demo=agent.agent_output_demo or "",
-            agent_image=agent.agent_image,
-            creator=agent.creator_username or "",
-            creator_avatar=agent.creator_avatar or "",
-            sub_heading=agent.sub_heading,
-            description=agent.description,
-            categories=agent.categories,
-            runs=agent.runs,
-            rating=agent.rating,
-            versions=agent.versions,
-            graph_id=agent.graph_id,
-            graph_versions=agent.graph_versions,
-            last_updated=agent.updated_at,
-            recommended_schedule_cron=agent.recommended_schedule_cron,
-            active_version_id=agent.listing_version_id,
-            has_approved_version=True,  # StoreAgent view only has approved agents
-        )
-

+class Creator(pydantic.BaseModel):
-class Profile(pydantic.BaseModel):
-    """Marketplace user profile (only attributes that the user can update)"""

-    username: str
     name: str
+    username: str
     description: str
-    avatar_url: str | None
+    avatar_url: str
-    links: list[str]


-class ProfileDetails(Profile):
-    """Marketplace user profile (including read-only fields)"""

-    is_featured: bool

-    @classmethod
-    def from_db(cls, profile: "prisma.models.Profile") -> "ProfileDetails":
-        return cls(
-            name=profile.name,
-            username=profile.username,
-            avatar_url=profile.avatarUrl,
-            description=profile.description,
-            links=profile.links,
-            is_featured=profile.isFeatured,
-        )


-class CreatorDetails(ProfileDetails):
-    """Marketplace creator profile details, including aggregated stats"""

     num_agents: int
-    agent_runs: int
     agent_rating: float
-    top_categories: list[str]
+    agent_runs: int
+    is_featured: bool
-    @classmethod
-    def from_db(cls, creator: "prisma.models.Creator") -> "CreatorDetails":  # type: ignore[override]
-        return cls(
-            name=creator.name,
-            username=creator.username,
-            avatar_url=creator.avatar_url,
-            description=creator.description,
-            links=creator.links,
-            is_featured=creator.is_featured,
-            num_agents=creator.num_agents,
-            agent_runs=creator.agent_runs,
-            agent_rating=creator.agent_rating,
-            top_categories=creator.top_categories,
-        )


 class CreatorsResponse(pydantic.BaseModel):
-    creators: List[CreatorDetails]
+    creators: List[Creator]
     pagination: Pagination


-class StoreSubmission(pydantic.BaseModel):
+class CreatorDetails(pydantic.BaseModel):
-    # From StoreListing:
+    name: str
-    listing_id: str
+    username: str
-    user_id: str
+    description: str
-    slug: str
+    links: list[str]
+    avatar_url: str
+    agent_rating: float
+    agent_runs: int
+    top_categories: list[str]

-    # From StoreListingVersion:
-    listing_version_id: str
-    listing_version: int
-    graph_id: str
-    graph_version: int
+class Profile(pydantic.BaseModel):
+    name: str
+    username: str
+    description: str
+    links: list[str]
+    avatar_url: str
+    is_featured: bool = False


+class StoreSubmission(pydantic.BaseModel):
+    listing_id: str
+    agent_id: str
+    agent_version: int
     name: str
     sub_heading: str
+    slug: str
     description: str
-    instructions: str | None
+    instructions: str | None = None
-    categories: list[str]
     image_urls: list[str]
-    video_url: str | None
+    date_submitted: datetime.datetime
-    agent_output_demo_url: str | None

-    submitted_at: datetime.datetime | None
-    changes_summary: str | None
     status: prisma.enums.SubmissionStatus
-    reviewed_at: datetime.datetime | None = None
+    runs: int
+    rating: float
+    store_listing_version_id: str | None = None
+    version: int | None = None  # Actual version number from the database

     reviewer_id: str | None = None
     review_comments: str | None = None  # External comments visible to creator
+    internal_comments: str | None = None  # Private notes for admin use only
+    reviewed_at: datetime.datetime | None = None
+    changes_summary: str | None = None

-    # Aggregated from AgentGraphExecutions and StoreListingReviews:
+    # Additional fields for editing
-    run_count: int = 0
+    video_url: str | None = None
-    review_count: int = 0
+    agent_output_demo_url: str | None = None
-    review_avg_rating: float = 0.0
+    categories: list[str] = []

-    @classmethod
-    def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
-        """Construct from the StoreSubmission Prisma view."""
-        return cls(
-            listing_id=_sub.listing_id,
-            user_id=_sub.user_id,
-            slug=_sub.slug,
-            listing_version_id=_sub.listing_version_id,
-            listing_version=_sub.listing_version,
-            graph_id=_sub.graph_id,
-            graph_version=_sub.graph_version,
-            name=_sub.name,
-            sub_heading=_sub.sub_heading,
-            description=_sub.description,
-            instructions=_sub.instructions,
-            categories=_sub.categories,
-            image_urls=_sub.image_urls,
-            video_url=_sub.video_url,
-            agent_output_demo_url=_sub.agent_output_demo_url,
-            submitted_at=_sub.submitted_at,
-            changes_summary=_sub.changes_summary,
-            status=_sub.status,
-            reviewed_at=_sub.reviewed_at,
-            reviewer_id=_sub.reviewer_id,
-            review_comments=_sub.review_comments,
-            run_count=_sub.run_count,
-            review_count=_sub.review_count,
-            review_avg_rating=_sub.review_avg_rating,
-        )

-    @classmethod
-    def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
-        """
-        Construct from the StoreListingVersion Prisma model (with StoreListing included)
-        """
-        if not (_l := _lv.StoreListing):
-            raise ValueError("StoreListingVersion must have included StoreListing")

-        return cls(
-            listing_id=_l.id,
-            user_id=_l.owningUserId,
-            slug=_l.slug,
-            listing_version_id=_lv.id,
-            listing_version=_lv.version,
-            graph_id=_lv.agentGraphId,
-            graph_version=_lv.agentGraphVersion,
-            name=_lv.name,
-            sub_heading=_lv.subHeading,
-            description=_lv.description,
-            instructions=_lv.instructions,
-            categories=_lv.categories,
-            image_urls=_lv.imageUrls,
-            video_url=_lv.videoUrl,
-            agent_output_demo_url=_lv.agentOutputDemoUrl,
-            submitted_at=_lv.submittedAt,
-            changes_summary=_lv.changesSummary,
-            status=_lv.submissionStatus,
-            reviewed_at=_lv.reviewedAt,
-            reviewer_id=_lv.reviewerId,
-            review_comments=_lv.reviewComments,
-        )


 class StoreSubmissionsResponse(pydantic.BaseModel):
@@ -273,12 +144,33 @@ class StoreSubmissionsResponse(pydantic.BaseModel):
     pagination: Pagination


+class StoreListingWithVersions(pydantic.BaseModel):
+    """A store listing with its version history"""
+
+    listing_id: str
+    slug: str
+    agent_id: str
+    agent_version: int
+    active_version_id: str | None = None
+    has_approved_version: bool = False
+    creator_email: str | None = None
+    latest_version: StoreSubmission | None = None
+    versions: list[StoreSubmission] = []
+
+
+class StoreListingsWithVersionsResponse(pydantic.BaseModel):
+    """Response model for listings with version history"""
+
+    listings: list[StoreListingWithVersions]
+    pagination: Pagination
+
+
 class StoreSubmissionRequest(pydantic.BaseModel):
-    graph_id: str = pydantic.Field(
+    agent_id: str = pydantic.Field(
-        ..., min_length=1, description="Graph ID cannot be empty"
+        ..., min_length=1, description="Agent ID cannot be empty"
     )
-    graph_version: int = pydantic.Field(
+    agent_version: int = pydantic.Field(
-        ..., gt=0, description="Graph version must be greater than 0"
+        ..., gt=0, description="Agent version must be greater than 0"
     )
     slug: str
     name: str
@@ -306,42 +198,12 @@ class StoreSubmissionEditRequest(pydantic.BaseModel):
     recommended_schedule_cron: str | None = None


-class StoreSubmissionAdminView(StoreSubmission):
+class ProfileDetails(pydantic.BaseModel):
-    internal_comments: str | None  # Private admin notes
+    name: str
+    username: str
-    @classmethod
+    description: str
-    def from_db(cls, _sub: "prisma.models.StoreSubmission") -> Self:
+    links: list[str]
-        return cls(
+    avatar_url: str | None = None
-            **StoreSubmission.from_db(_sub).model_dump(),
-            internal_comments=_sub.internal_comments,
-        )

-    @classmethod
-    def from_listing_version(cls, _lv: "prisma.models.StoreListingVersion") -> Self:
-        return cls(
-            **StoreSubmission.from_listing_version(_lv).model_dump(),
-            internal_comments=_lv.internalComments,
-        )


-class StoreListingWithVersionsAdminView(pydantic.BaseModel):
-    """A store listing with its version history"""

-    listing_id: str
-    graph_id: str
-    slug: str
-    active_listing_version_id: str | None = None
-    has_approved_version: bool = False
-    creator_email: str | None = None
-    latest_version: StoreSubmissionAdminView | None = None
-    versions: list[StoreSubmissionAdminView] = []


-class StoreListingsWithVersionsAdminViewResponse(pydantic.BaseModel):
-    """Response model for listings with version history"""

-    listings: list[StoreListingWithVersionsAdminView]
-    pagination: Pagination


 class StoreReview(pydantic.BaseModel):
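The model changes above keep moving between Prisma's camelCase column names (`storeListingVersionId`, `agentGraphId`, `internalComments`) and the API's snake_case fields (`store_listing_version_id`, `agent_graph_id`). A hypothetical helper, not part of the codebase, showing that mapping mechanically:

```python
import re


def camel_to_snake(name: str) -> str:
    # Insert an underscore before every interior uppercase letter,
    # then lowercase the result: "agentGraphId" -> "agent_graph_id".
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
```

In the diff itself the mapping is done field by field (or avoided entirely by matching the DB name, as `agentGraphId` on `StoreAgentDetails` does), but the convention gap is the same.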
@@ -0,0 +1,203 @@
+import datetime
+
+import prisma.enums
+
+from . import model as store_model
+
+
+def test_pagination():
+    pagination = store_model.Pagination(
+        total_items=100, total_pages=5, current_page=2, page_size=20
+    )
+    assert pagination.total_items == 100
+    assert pagination.total_pages == 5
+    assert pagination.current_page == 2
+    assert pagination.page_size == 20
+
+
+def test_store_agent():
+    agent = store_model.StoreAgent(
+        slug="test-agent",
+        agent_name="Test Agent",
+        agent_image="test.jpg",
+        creator="creator1",
+        creator_avatar="avatar.jpg",
+        sub_heading="Test subheading",
+        description="Test description",
+        runs=50,
+        rating=4.5,
+        agent_graph_id="test-graph-id",
+    )
+    assert agent.slug == "test-agent"
+    assert agent.agent_name == "Test Agent"
+    assert agent.runs == 50
+    assert agent.rating == 4.5
+    assert agent.agent_graph_id == "test-graph-id"
+
+
+def test_store_agents_response():
+    response = store_model.StoreAgentsResponse(
+        agents=[
+            store_model.StoreAgent(
+                slug="test-agent",
+                agent_name="Test Agent",
+                agent_image="test.jpg",
+                creator="creator1",
+                creator_avatar="avatar.jpg",
+                sub_heading="Test subheading",
+                description="Test description",
+                runs=50,
+                rating=4.5,
+                agent_graph_id="test-graph-id",
+            )
+        ],
+        pagination=store_model.Pagination(
+            total_items=1, total_pages=1, current_page=1, page_size=20
+        ),
+    )
+    assert len(response.agents) == 1
+    assert response.pagination.total_items == 1
+
+
+def test_store_agent_details():
+    details = store_model.StoreAgentDetails(
+        store_listing_version_id="version123",
+        slug="test-agent",
+        agent_name="Test Agent",
+        agent_video="video.mp4",
+        agent_output_demo="demo.mp4",
+        agent_image=["image1.jpg", "image2.jpg"],
+        creator="creator1",
+        creator_avatar="avatar.jpg",
+        sub_heading="Test subheading",
+        description="Test description",
+        categories=["cat1", "cat2"],
+        runs=50,
+        rating=4.5,
+        versions=["1.0", "2.0"],
+        agentGraphVersions=["1", "2"],
+        agentGraphId="test-graph-id",
+        last_updated=datetime.datetime.now(),
+    )
+    assert details.slug == "test-agent"
+    assert len(details.agent_image) == 2
+    assert len(details.categories) == 2
+    assert len(details.versions) == 2
+
+
+def test_creator():
+    creator = store_model.Creator(
+        agent_rating=4.8,
+        agent_runs=1000,
+        name="Test Creator",
+        username="creator1",
+        description="Test description",
+        avatar_url="avatar.jpg",
+        num_agents=5,
+        is_featured=False,
+    )
+    assert creator.name == "Test Creator"
+    assert creator.num_agents == 5
+
+
+def test_creators_response():
+    response = store_model.CreatorsResponse(
+        creators=[
+            store_model.Creator(
+                agent_rating=4.8,
+                agent_runs=1000,
+                name="Test Creator",
+                username="creator1",
+                description="Test description",
+                avatar_url="avatar.jpg",
+                num_agents=5,
+                is_featured=False,
+            )
+        ],
+        pagination=store_model.Pagination(
+            total_items=1, total_pages=1, current_page=1, page_size=20
+        ),
+    )
+    assert len(response.creators) == 1
+    assert response.pagination.total_items == 1
+
+
+def test_creator_details():
+    details = store_model.CreatorDetails(
+        name="Test Creator",
+        username="creator1",
+        description="Test description",
+        links=["link1.com", "link2.com"],
+        avatar_url="avatar.jpg",
+        agent_rating=4.8,
+        agent_runs=1000,
+        top_categories=["cat1", "cat2"],
+    )
+    assert details.name == "Test Creator"
+    assert len(details.links) == 2
+    assert details.agent_rating == 4.8
+    assert len(details.top_categories) == 2
+
+
+def test_store_submission():
+    submission = store_model.StoreSubmission(
+        listing_id="listing123",
+        agent_id="agent123",
+        agent_version=1,
+        sub_heading="Test subheading",
+        name="Test Agent",
+        slug="test-agent",
+        description="Test description",
+        image_urls=["image1.jpg", "image2.jpg"],
+        date_submitted=datetime.datetime(2023, 1, 1),
+        status=prisma.enums.SubmissionStatus.PENDING,
+        runs=50,
+        rating=4.5,
+    )
+    assert submission.name == "Test Agent"
+    assert len(submission.image_urls) == 2
+    assert submission.status == prisma.enums.SubmissionStatus.PENDING
+
+
+def test_store_submissions_response():
+    response = store_model.StoreSubmissionsResponse(
+        submissions=[
+            store_model.StoreSubmission(
+                listing_id="listing123",
+                agent_id="agent123",
+                agent_version=1,
+                sub_heading="Test subheading",
+                name="Test Agent",
+                slug="test-agent",
+                description="Test description",
+                image_urls=["image1.jpg"],
+                date_submitted=datetime.datetime(2023, 1, 1),
+                status=prisma.enums.SubmissionStatus.PENDING,
+                runs=50,
+                rating=4.5,
+            )
+        ],
+        pagination=store_model.Pagination(
+            total_items=1, total_pages=1, current_page=1, page_size=20
+        ),
+    )
+    assert len(response.submissions) == 1
+    assert response.pagination.total_items == 1
+
+
+def test_store_submission_request():
+    request = store_model.StoreSubmissionRequest(
+        agent_id="agent123",
+        agent_version=1,
+        slug="test-agent",
+        name="Test Agent",
+        sub_heading="Test subheading",
+        video_url="video.mp4",
+        image_urls=["image1.jpg", "image2.jpg"],
+        description="Test description",
+        categories=["cat1", "cat2"],
+    )
+    assert request.agent_id == "agent123"
+    assert request.agent_version == 1
+    assert len(request.image_urls) == 2
+    assert len(request.categories) == 2
@@ -1,17 +1,16 @@
 import logging
 import tempfile
+import typing
 import urllib.parse
+from typing import Literal

 import autogpt_libs.auth
 import fastapi
 import fastapi.responses
 import prisma.enums
-from fastapi import Query, Security
-from pydantic import BaseModel

 import backend.data.graph
 import backend.util.json
-from backend.util.exceptions import NotFoundError
 from backend.util.models import Pagination

 from . import cache as store_cache
@@ -35,15 +34,22 @@ router = fastapi.APIRouter()
     "/profile",
     summary="Get user profile",
     tags=["store", "private"],
-    dependencies=[Security(autogpt_libs.auth.requires_user)],
+    dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
+    response_model=store_model.ProfileDetails,
 )
 async def get_profile(
-    user_id: str = Security(autogpt_libs.auth.get_user_id),
+    user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
-) -> store_model.ProfileDetails:
+):
-    """Get the profile details for the authenticated user."""
+    """
+    Get the profile details for the authenticated user.
+    Cached for 1 hour per user.
+    """
     profile = await store_db.get_user_profile(user_id)
     if profile is None:
-        raise NotFoundError("User does not have a profile yet")
+        return fastapi.responses.JSONResponse(
+            status_code=404,
+            content={"detail": "Profile not found"},
+        )
     return profile


@@ -51,17 +57,98 @@ async def get_profile(
     "/profile",
     summary="Update user profile",
     tags=["store", "private"],
-    dependencies=[Security(autogpt_libs.auth.requires_user)],
+    dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
+    response_model=store_model.CreatorDetails,
 )
 async def update_or_create_profile(
     profile: store_model.Profile,
-    user_id: str = Security(autogpt_libs.auth.get_user_id),
+    user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
-) -> store_model.ProfileDetails:
+):
-    """Update the store profile for the authenticated user."""
+    """
+    Update the store profile for the authenticated user.
+
+    Args:
+        profile (Profile): The updated profile details
+        user_id (str): ID of the authenticated user
+
+    Returns:
+        CreatorDetails: The updated profile
+
+    Raises:
+        HTTPException: If there is an error updating the profile
+    """
     updated_profile = await store_db.update_profile(user_id=user_id, profile=profile)
     return updated_profile


+##############################################
+############### Agent Endpoints ##############
+##############################################
+
+
+@router.get(
+    "/agents",
+    summary="List store agents",
+    tags=["store", "public"],
+    response_model=store_model.StoreAgentsResponse,
+)
+async def get_agents(
+    featured: bool = False,
+    creator: str | None = None,
+    sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
+    search_query: str | None = None,
+    category: str | None = None,
+    page: int = 1,
+    page_size: int = 20,
+):
+    """
+    Get a paginated list of agents from the store with optional filtering and sorting.
+
+    Args:
+        featured (bool, optional): Filter to only show featured agents. Defaults to False.
+        creator (str | None, optional): Filter agents by creator username. Defaults to None.
+        sorted_by (str | None, optional): Sort agents by "runs" or "rating". Defaults to None.
+        search_query (str | None, optional): Search agents by name, subheading and description. Defaults to None.
+        category (str | None, optional): Filter agents by category. Defaults to None.
+        page (int, optional): Page number for pagination. Defaults to 1.
+        page_size (int, optional): Number of agents per page. Defaults to 20.
+
+    Returns:
+        StoreAgentsResponse: Paginated list of agents matching the filters
+
+    Raises:
+        HTTPException: If page or page_size are less than 1
+
+    Used for:
+    - Home Page Featured Agents
+    - Home Page Top Agents
+    - Search Results
+    - Agent Details - Other Agents By Creator
+    - Agent Details - Similar Agents
+    - Creator Details - Agents By Creator
+    """
+    if page < 1:
+        raise fastapi.HTTPException(
+            status_code=422, detail="Page must be greater than 0"
+        )
+
+    if page_size < 1:
+        raise fastapi.HTTPException(
+            status_code=422, detail="Page size must be greater than 0"
+        )
+
+    agents = await store_cache._get_cached_store_agents(
+        featured=featured,
+        creator=creator,
+        sorted_by=sorted_by,
+        search_query=search_query,
+        category=category,
+        page=page,
+        page_size=page_size,
+    )
+    return agents


 ##############################################
 ############### Search Endpoints #############
 ##############################################
@@ -71,30 +158,60 @@ async def update_or_create_profile(
     "/search",
     summary="Unified search across all content types",
     tags=["store", "public"],
+    response_model=store_model.UnifiedSearchResponse,
 )
 async def unified_search(
     query: str,
-    content_types: list[prisma.enums.ContentType] | None = Query(
+    content_types: list[str] | None = fastapi.Query(
         default=None,
-        description="Content types to search. If not specified, searches all.",
+        description="Content types to search: STORE_AGENT, BLOCK, DOCUMENTATION. If not specified, searches all.",
     ),
-    page: int = Query(ge=1, default=1),
+    page: int = 1,
-    page_size: int = Query(ge=1, default=20),
+    page_size: int = 20,
-    user_id: str | None = Security(
+    user_id: str | None = fastapi.Security(
         autogpt_libs.auth.get_optional_user_id, use_cache=False
     ),
-) -> store_model.UnifiedSearchResponse:
+):
     """
-    Search across all content types (marketplace agents, blocks, documentation)
-    using hybrid search.
+    Search across all content types (store agents, blocks, documentation) using hybrid search.

     Combines semantic (embedding-based) and lexical (text-based) search for best results.
|
Combines semantic (embedding-based) and lexical (text-based) search for best results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query string
|
||||||
|
content_types: Optional list of content types to filter by (STORE_AGENT, BLOCK, DOCUMENTATION)
|
||||||
|
page: Page number for pagination (default 1)
|
||||||
|
page_size: Number of results per page (default 20)
|
||||||
|
user_id: Optional authenticated user ID (for user-scoped content in future)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UnifiedSearchResponse: Paginated list of search results with relevance scores
|
||||||
"""
|
"""
|
||||||
|
if page < 1:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=422, detail="Page must be greater than 0"
|
||||||
|
)
|
||||||
|
|
||||||
|
if page_size < 1:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=422, detail="Page size must be greater than 0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert string content types to enum
|
||||||
|
content_type_enums: list[prisma.enums.ContentType] | None = None
|
||||||
|
if content_types:
|
||||||
|
try:
|
||||||
|
content_type_enums = [prisma.enums.ContentType(ct) for ct in content_types]
|
||||||
|
except ValueError as e:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=422,
|
||||||
|
detail=f"Invalid content type. Valid values: STORE_AGENT, BLOCK, DOCUMENTATION. Error: {e}",
|
||||||
|
)
|
||||||
|
|
||||||
# Perform unified hybrid search
|
# Perform unified hybrid search
|
||||||
results, total = await store_hybrid_search.unified_hybrid_search(
|
results, total = await store_hybrid_search.unified_hybrid_search(
|
||||||
query=query,
|
query=query,
|
||||||
content_types=content_types,
|
content_types=content_type_enums,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
page=page,
|
page=page,
|
||||||
page_size=page_size,
|
page_size=page_size,
|
||||||
@@ -128,69 +245,22 @@ async def unified_search(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
##############################################
|
|
||||||
############### Agent Endpoints ##############
|
|
||||||
##############################################
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/agents",
|
|
||||||
summary="List store agents",
|
|
||||||
tags=["store", "public"],
|
|
||||||
)
|
|
||||||
async def get_agents(
|
|
||||||
featured: bool = Query(
|
|
||||||
default=False, description="Filter to only show featured agents"
|
|
||||||
),
|
|
||||||
creator: str | None = Query(
|
|
||||||
default=None, description="Filter agents by creator username"
|
|
||||||
),
|
|
||||||
category: str | None = Query(default=None, description="Filter agents by category"),
|
|
||||||
search_query: str | None = Query(
|
|
||||||
default=None, description="Literal + semantic search on names and descriptions"
|
|
||||||
),
|
|
||||||
sorted_by: store_db.StoreAgentsSortOptions | None = Query(
|
|
||||||
default=None,
|
|
||||||
description="Property to sort results by. Ignored if search_query is provided.",
|
|
||||||
),
|
|
||||||
page: int = Query(ge=1, default=1),
|
|
||||||
page_size: int = Query(ge=1, default=20),
|
|
||||||
) -> store_model.StoreAgentsResponse:
|
|
||||||
"""
|
|
||||||
Get a paginated list of agents from the marketplace,
|
|
||||||
with optional filtering and sorting.
|
|
||||||
|
|
||||||
Used for:
|
|
||||||
- Home Page Featured Agents
|
|
||||||
- Home Page Top Agents
|
|
||||||
- Search Results
|
|
||||||
- Agent Details - Other Agents By Creator
|
|
||||||
- Agent Details - Similar Agents
|
|
||||||
- Creator Details - Agents By Creator
|
|
||||||
"""
|
|
||||||
agents = await store_cache._get_cached_store_agents(
|
|
||||||
featured=featured,
|
|
||||||
creator=creator,
|
|
||||||
sorted_by=sorted_by,
|
|
||||||
search_query=search_query,
|
|
||||||
category=category,
|
|
||||||
page=page,
|
|
||||||
page_size=page_size,
|
|
||||||
)
|
|
||||||
return agents
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/agents/{username}/{agent_name}",
|
"/agents/{username}/{agent_name}",
|
||||||
summary="Get specific agent",
|
summary="Get specific agent",
|
||||||
tags=["store", "public"],
|
tags=["store", "public"],
|
||||||
|
response_model=store_model.StoreAgentDetails,
|
||||||
)
|
)
|
||||||
async def get_agent_by_name(
|
async def get_agent(
|
||||||
username: str,
|
username: str,
|
||||||
agent_name: str,
|
agent_name: str,
|
||||||
include_changelog: bool = Query(default=False),
|
include_changelog: bool = fastapi.Query(default=False),
|
||||||
) -> store_model.StoreAgentDetails:
|
):
|
||||||
"""Get details of a marketplace agent"""
|
"""
|
||||||
|
This is only used on the AgentDetails Page.
|
||||||
|
|
||||||
|
It returns the store listing agents details.
|
||||||
|
"""
|
||||||
username = urllib.parse.unquote(username).lower()
|
username = urllib.parse.unquote(username).lower()
|
||||||
# URL decode the agent name since it comes from the URL path
|
# URL decode the agent name since it comes from the URL path
|
||||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||||
@@ -200,82 +270,76 @@ async def get_agent_by_name(
|
|||||||
return agent
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/graph/{store_listing_version_id}",
|
||||||
|
summary="Get agent graph",
|
||||||
|
tags=["store"],
|
||||||
|
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||||
|
)
|
||||||
|
async def get_graph_meta_by_store_listing_version_id(
|
||||||
|
store_listing_version_id: str,
|
||||||
|
) -> backend.data.graph.GraphModelWithoutNodes:
|
||||||
|
"""
|
||||||
|
Get Agent Graph from Store Listing Version ID.
|
||||||
|
"""
|
||||||
|
graph = await store_db.get_available_graph(store_listing_version_id)
|
||||||
|
return graph
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/agents/{store_listing_version_id}",
|
||||||
|
summary="Get agent by version",
|
||||||
|
tags=["store"],
|
||||||
|
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||||
|
response_model=store_model.StoreAgentDetails,
|
||||||
|
)
|
||||||
|
async def get_store_agent(store_listing_version_id: str):
|
||||||
|
"""
|
||||||
|
Get Store Agent Details from Store Listing Version ID.
|
||||||
|
"""
|
||||||
|
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
|
||||||
|
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/agents/{username}/{agent_name}/review",
|
"/agents/{username}/{agent_name}/review",
|
||||||
summary="Create agent review",
|
summary="Create agent review",
|
||||||
tags=["store"],
|
tags=["store"],
|
||||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||||
|
response_model=store_model.StoreReview,
|
||||||
)
|
)
|
||||||
async def post_user_review_for_agent(
|
async def create_review(
|
||||||
username: str,
|
username: str,
|
||||||
agent_name: str,
|
agent_name: str,
|
||||||
review: store_model.StoreReviewCreate,
|
review: store_model.StoreReviewCreate,
|
||||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||||
) -> store_model.StoreReview:
|
):
|
||||||
"""Post a user review on a marketplace agent listing"""
|
"""
|
||||||
|
Create a review for a store agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username: Creator's username
|
||||||
|
agent_name: Name/slug of the agent
|
||||||
|
review: Review details including score and optional comments
|
||||||
|
user_id: ID of authenticated user creating the review
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created review
|
||||||
|
"""
|
||||||
username = urllib.parse.unquote(username).lower()
|
username = urllib.parse.unquote(username).lower()
|
||||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||||
|
# Create the review
|
||||||
created_review = await store_db.create_store_review(
|
created_review = await store_db.create_store_review(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
store_listing_version_id=review.store_listing_version_id,
|
store_listing_version_id=review.store_listing_version_id,
|
||||||
score=review.score,
|
score=review.score,
|
||||||
comments=review.comments,
|
comments=review.comments,
|
||||||
)
|
)
|
||||||
|
|
||||||
return created_review
|
return created_review
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/listings/versions/{store_listing_version_id}",
|
|
||||||
summary="Get agent by version",
|
|
||||||
tags=["store"],
|
|
||||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
|
||||||
)
|
|
||||||
async def get_agent_by_listing_version(
|
|
||||||
store_listing_version_id: str,
|
|
||||||
) -> store_model.StoreAgentDetails:
|
|
||||||
agent = await store_db.get_store_agent_by_version_id(store_listing_version_id)
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/listings/versions/{store_listing_version_id}/graph",
|
|
||||||
summary="Get agent graph",
|
|
||||||
tags=["store"],
|
|
||||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
|
||||||
)
|
|
||||||
async def get_graph_meta_by_store_listing_version_id(
|
|
||||||
store_listing_version_id: str,
|
|
||||||
) -> backend.data.graph.GraphModelWithoutNodes:
|
|
||||||
"""Get outline of graph belonging to a specific marketplace listing version"""
|
|
||||||
graph = await store_db.get_available_graph(store_listing_version_id)
|
|
||||||
return graph
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/listings/versions/{store_listing_version_id}/graph/download",
|
|
||||||
summary="Download agent file",
|
|
||||||
tags=["store", "public"],
|
|
||||||
)
|
|
||||||
async def download_agent_file(
|
|
||||||
store_listing_version_id: str,
|
|
||||||
) -> fastapi.responses.FileResponse:
|
|
||||||
"""Download agent graph file for a specific marketplace listing version"""
|
|
||||||
graph_data = await store_db.get_agent(store_listing_version_id)
|
|
||||||
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
|
||||||
|
|
||||||
# Sending graph as a stream (similar to marketplace v1)
|
|
||||||
with tempfile.NamedTemporaryFile(
|
|
||||||
mode="w", suffix=".json", delete=False
|
|
||||||
) as tmp_file:
|
|
||||||
tmp_file.write(backend.util.json.dumps(graph_data))
|
|
||||||
tmp_file.flush()
|
|
||||||
|
|
||||||
return fastapi.responses.FileResponse(
|
|
||||||
tmp_file.name, filename=file_name, media_type="application/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
##############################################
|
##############################################
|
||||||
############# Creator Endpoints #############
|
############# Creator Endpoints #############
|
||||||
##############################################
|
##############################################
|
||||||
@@ -285,19 +349,37 @@ async def download_agent_file(
|
|||||||
"/creators",
|
"/creators",
|
||||||
summary="List store creators",
|
summary="List store creators",
|
||||||
tags=["store", "public"],
|
tags=["store", "public"],
|
||||||
|
response_model=store_model.CreatorsResponse,
|
||||||
)
|
)
|
||||||
async def get_creators(
|
async def get_creators(
|
||||||
featured: bool = Query(
|
featured: bool = False,
|
||||||
default=False, description="Filter to only show featured creators"
|
search_query: str | None = None,
|
||||||
),
|
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
|
||||||
search_query: str | None = Query(
|
page: int = 1,
|
||||||
default=None, description="Literal + semantic search on names and descriptions"
|
page_size: int = 20,
|
||||||
),
|
):
|
||||||
sorted_by: store_db.StoreCreatorsSortOptions | None = None,
|
"""
|
||||||
page: int = Query(ge=1, default=1),
|
This is needed for:
|
||||||
page_size: int = Query(ge=1, default=20),
|
- Home Page Featured Creators
|
||||||
) -> store_model.CreatorsResponse:
|
- Search Results Page
|
||||||
"""List or search marketplace creators"""
|
|
||||||
|
---
|
||||||
|
|
||||||
|
To support this functionality we need:
|
||||||
|
- featured: bool - to limit the list to just featured agents
|
||||||
|
- search_query: str - vector search based on the creators profile description.
|
||||||
|
- sorted_by: [agent_rating, agent_runs] -
|
||||||
|
"""
|
||||||
|
if page < 1:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=422, detail="Page must be greater than 0"
|
||||||
|
)
|
||||||
|
|
||||||
|
if page_size < 1:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=422, detail="Page size must be greater than 0"
|
||||||
|
)
|
||||||
|
|
||||||
creators = await store_cache._get_cached_store_creators(
|
creators = await store_cache._get_cached_store_creators(
|
||||||
featured=featured,
|
featured=featured,
|
||||||
search_query=search_query,
|
search_query=search_query,
|
||||||
@@ -309,12 +391,18 @@ async def get_creators(
|
|||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/creators/{username}",
|
"/creator/{username}",
|
||||||
summary="Get creator details",
|
summary="Get creator details",
|
||||||
tags=["store", "public"],
|
tags=["store", "public"],
|
||||||
|
response_model=store_model.CreatorDetails,
|
||||||
)
|
)
|
||||||
async def get_creator(username: str) -> store_model.CreatorDetails:
|
async def get_creator(
|
||||||
"""Get details on a marketplace creator"""
|
username: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get the details of a creator.
|
||||||
|
- Creator Details Page
|
||||||
|
"""
|
||||||
username = urllib.parse.unquote(username).lower()
|
username = urllib.parse.unquote(username).lower()
|
||||||
creator = await store_cache._get_cached_creator_details(username=username)
|
creator = await store_cache._get_cached_creator_details(username=username)
|
||||||
return creator
|
return creator
|
||||||
@@ -326,17 +414,20 @@ async def get_creator(username: str) -> store_model.CreatorDetails:
|
|||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/my-unpublished-agents",
|
"/myagents",
|
||||||
summary="Get my agents",
|
summary="Get my agents",
|
||||||
tags=["store", "private"],
|
tags=["store", "private"],
|
||||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||||
|
response_model=store_model.MyAgentsResponse,
|
||||||
)
|
)
|
||||||
async def get_my_unpublished_agents(
|
async def get_my_agents(
|
||||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||||
page: int = Query(ge=1, default=1),
|
page: typing.Annotated[int, fastapi.Query(ge=1)] = 1,
|
||||||
page_size: int = Query(ge=1, default=20),
|
page_size: typing.Annotated[int, fastapi.Query(ge=1)] = 20,
|
||||||
) -> store_model.MyUnpublishedAgentsResponse:
|
):
|
||||||
"""List the authenticated user's unpublished agents"""
|
"""
|
||||||
|
Get user's own agents.
|
||||||
|
"""
|
||||||
agents = await store_db.get_my_agents(user_id, page=page, page_size=page_size)
|
agents = await store_db.get_my_agents(user_id, page=page, page_size=page_size)
|
||||||
return agents
|
return agents
|
||||||
|
|
||||||
@@ -345,17 +436,28 @@ async def get_my_unpublished_agents(
|
|||||||
"/submissions/{submission_id}",
|
"/submissions/{submission_id}",
|
||||||
summary="Delete store submission",
|
summary="Delete store submission",
|
||||||
tags=["store", "private"],
|
tags=["store", "private"],
|
||||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||||
|
response_model=bool,
|
||||||
)
|
)
|
||||||
async def delete_submission(
|
async def delete_submission(
|
||||||
submission_id: str,
|
submission_id: str,
|
||||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||||
) -> bool:
|
):
|
||||||
"""Delete a marketplace listing submission"""
|
"""
|
||||||
|
Delete a store listing submission.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): ID of the authenticated user
|
||||||
|
submission_id (str): ID of the submission to be deleted
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the submission was successfully deleted, False otherwise
|
||||||
|
"""
|
||||||
result = await store_db.delete_store_submission(
|
result = await store_db.delete_store_submission(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
submission_id=submission_id,
|
submission_id=submission_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -363,14 +465,37 @@ async def delete_submission(
|
|||||||
"/submissions",
|
"/submissions",
|
||||||
summary="List my submissions",
|
summary="List my submissions",
|
||||||
tags=["store", "private"],
|
tags=["store", "private"],
|
||||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||||
|
response_model=store_model.StoreSubmissionsResponse,
|
||||||
)
|
)
|
||||||
async def get_submissions(
|
async def get_submissions(
|
||||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||||
page: int = Query(ge=1, default=1),
|
page: int = 1,
|
||||||
page_size: int = Query(ge=1, default=20),
|
page_size: int = 20,
|
||||||
) -> store_model.StoreSubmissionsResponse:
|
):
|
||||||
"""List the authenticated user's marketplace listing submissions"""
|
"""
|
||||||
|
Get a paginated list of store submissions for the authenticated user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): ID of the authenticated user
|
||||||
|
page (int, optional): Page number for pagination. Defaults to 1.
|
||||||
|
page_size (int, optional): Number of submissions per page. Defaults to 20.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StoreListingsResponse: Paginated list of store submissions
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If page or page_size are less than 1
|
||||||
|
"""
|
||||||
|
if page < 1:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=422, detail="Page must be greater than 0"
|
||||||
|
)
|
||||||
|
|
||||||
|
if page_size < 1:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=422, detail="Page size must be greater than 0"
|
||||||
|
)
|
||||||
listings = await store_db.get_store_submissions(
|
listings = await store_db.get_store_submissions(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
page=page,
|
page=page,
|
||||||
@@ -383,17 +508,30 @@ async def get_submissions(
|
|||||||
"/submissions",
|
"/submissions",
|
||||||
summary="Create store submission",
|
summary="Create store submission",
|
||||||
tags=["store", "private"],
|
tags=["store", "private"],
|
||||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||||
|
response_model=store_model.StoreSubmission,
|
||||||
)
|
)
|
||||||
async def create_submission(
|
async def create_submission(
|
||||||
submission_request: store_model.StoreSubmissionRequest,
|
submission_request: store_model.StoreSubmissionRequest,
|
||||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||||
) -> store_model.StoreSubmission:
|
):
|
||||||
"""Submit a new marketplace listing for review"""
|
"""
|
||||||
|
Create a new store listing submission.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
submission_request (StoreSubmissionRequest): The submission details
|
||||||
|
user_id (str): ID of the authenticated user submitting the listing
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StoreSubmission: The created store submission
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If there is an error creating the submission
|
||||||
|
"""
|
||||||
result = await store_db.create_store_submission(
|
result = await store_db.create_store_submission(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
graph_id=submission_request.graph_id,
|
agent_id=submission_request.agent_id,
|
||||||
graph_version=submission_request.graph_version,
|
agent_version=submission_request.agent_version,
|
||||||
slug=submission_request.slug,
|
slug=submission_request.slug,
|
||||||
name=submission_request.name,
|
name=submission_request.name,
|
||||||
video_url=submission_request.video_url,
|
video_url=submission_request.video_url,
|
||||||
@@ -406,6 +544,7 @@ async def create_submission(
|
|||||||
changes_summary=submission_request.changes_summary or "Initial Submission",
|
changes_summary=submission_request.changes_summary or "Initial Submission",
|
||||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -413,14 +552,28 @@ async def create_submission(
|
|||||||
"/submissions/{store_listing_version_id}",
|
"/submissions/{store_listing_version_id}",
|
||||||
summary="Edit store submission",
|
summary="Edit store submission",
|
||||||
tags=["store", "private"],
|
tags=["store", "private"],
|
||||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||||
|
response_model=store_model.StoreSubmission,
|
||||||
)
|
)
|
||||||
async def edit_submission(
|
async def edit_submission(
|
||||||
store_listing_version_id: str,
|
store_listing_version_id: str,
|
||||||
submission_request: store_model.StoreSubmissionEditRequest,
|
submission_request: store_model.StoreSubmissionEditRequest,
|
||||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||||
) -> store_model.StoreSubmission:
|
):
|
||||||
"""Update a pending marketplace listing submission"""
|
"""
|
||||||
|
Edit an existing store listing submission.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
store_listing_version_id (str): ID of the store listing version to edit
|
||||||
|
submission_request (StoreSubmissionRequest): The updated submission details
|
||||||
|
user_id (str): ID of the authenticated user editing the listing
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StoreSubmission: The updated store submission
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If there is an error editing the submission
|
||||||
|
"""
|
||||||
result = await store_db.edit_store_submission(
|
result = await store_db.edit_store_submission(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
store_listing_version_id=store_listing_version_id,
|
store_listing_version_id=store_listing_version_id,
|
||||||
@@ -435,6 +588,7 @@ async def edit_submission(
|
|||||||
changes_summary=submission_request.changes_summary,
|
changes_summary=submission_request.changes_summary,
|
||||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -442,61 +596,115 @@ async def edit_submission(
|
|||||||
"/submissions/media",
|
"/submissions/media",
|
||||||
summary="Upload submission media",
|
summary="Upload submission media",
|
||||||
tags=["store", "private"],
|
tags=["store", "private"],
|
||||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||||
)
|
)
|
||||||
async def upload_submission_media(
|
async def upload_submission_media(
|
||||||
file: fastapi.UploadFile,
|
file: fastapi.UploadFile,
|
||||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||||
) -> str:
|
):
|
||||||
"""Upload media for a marketplace listing submission"""
|
"""
|
||||||
|
Upload media (images/videos) for a store listing submission.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file (UploadFile): The media file to upload
|
||||||
|
user_id (str): ID of the authenticated user uploading the media
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: URL of the uploaded media file
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If there is an error uploading the media
|
||||||
|
"""
|
||||||
media_url = await store_media.upload_media(user_id=user_id, file=file)
|
media_url = await store_media.upload_media(user_id=user_id, file=file)
|
||||||
return media_url
|
return media_url
|
||||||
|
|
||||||
|
|
||||||
class ImageURLResponse(BaseModel):
|
|
||||||
image_url: str
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/submissions/generate_image",
|
"/submissions/generate_image",
|
||||||
summary="Generate submission image",
|
summary="Generate submission image",
|
||||||
tags=["store", "private"],
|
tags=["store", "private"],
|
||||||
dependencies=[Security(autogpt_libs.auth.requires_user)],
|
dependencies=[fastapi.Security(autogpt_libs.auth.requires_user)],
|
||||||
)
|
)
|
||||||
async def generate_image(
|
async def generate_image(
|
||||||
graph_id: str,
|
agent_id: str,
|
||||||
user_id: str = Security(autogpt_libs.auth.get_user_id),
|
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
|
||||||
) -> ImageURLResponse:
|
) -> fastapi.responses.Response:
|
||||||
"""
|
"""
|
||||||
Generate an image for a marketplace listing submission based on the properties
|
Generate an image for a store listing submission.
|
||||||
of a given graph.
|
|
||||||
|
Args:
|
||||||
|
agent_id (str): ID of the agent to generate an image for
|
||||||
|
user_id (str): ID of the authenticated user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSONResponse: JSON containing the URL of the generated image
|
||||||
"""
|
"""
|
||||||
graph = await backend.data.graph.get_graph(
|
agent = await backend.data.graph.get_graph(
|
||||||
graph_id=graph_id, version=None, user_id=user_id
|
graph_id=agent_id, version=None, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if not graph:
|
if not agent:
|
||||||
raise NotFoundError(f"Agent graph #{graph_id} not found")
|
raise fastapi.HTTPException(
|
||||||
|
status_code=404, detail=f"Agent with ID {agent_id} not found"
|
||||||
|
)
|
||||||
# Use .jpeg here since we are generating JPEG images
|
# Use .jpeg here since we are generating JPEG images
|
||||||
filename = f"agent_{graph_id}.jpeg"
|
filename = f"agent_{agent_id}.jpeg"
|
||||||
|
|
||||||
existing_url = await store_media.check_media_exists(user_id, filename)
|
existing_url = await store_media.check_media_exists(user_id, filename)
|
||||||
if existing_url:
|
if existing_url:
|
||||||
logger.info(f"Using existing image for agent graph {graph_id}")
|
logger.info(f"Using existing image for agent {agent_id}")
|
||||||
return ImageURLResponse(image_url=existing_url)
|
return fastapi.responses.JSONResponse(content={"image_url": existing_url})
|
||||||
# Generate agent image as JPEG
|
# Generate agent image as JPEG
|
||||||
image = await store_image_gen.generate_agent_image(agent=graph)
|
image = await store_image_gen.generate_agent_image(agent=agent)
|
||||||
|
|
||||||
# Create UploadFile with the correct filename and content_type
|
# Create UploadFile with the correct filename and content_type
|
||||||
image_file = fastapi.UploadFile(
|
image_file = fastapi.UploadFile(
|
||||||
file=image,
|
file=image,
|
||||||
filename=filename,
|
filename=filename,
|
||||||
)
|
)
|
||||||
|
|
||||||
image_url = await store_media.upload_media(
|
image_url = await store_media.upload_media(
|
||||||
user_id=user_id, file=image_file, use_file_name=True
|
user_id=user_id, file=image_file, use_file_name=True
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageURLResponse(image_url=image_url)
|
return fastapi.responses.JSONResponse(content={"image_url": image_url})
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/download/agents/{store_listing_version_id}",
|
||||||
|
summary="Download agent file",
|
||||||
|
tags=["store", "public"],
|
||||||
|
)
|
||||||
|
async def download_agent_file(
|
||||||
|
store_listing_version_id: str = fastapi.Path(
|
||||||
|
..., description="The ID of the agent to download"
|
||||||
|
),
|
||||||
|
) -> fastapi.responses.FileResponse:
|
||||||
|
"""
|
||||||
|
Download the agent file by streaming its content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
store_listing_version_id (str): The ID of the agent to download
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StreamingResponse: A streaming response containing the agent's graph data.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If the agent is not found or an unexpected error occurs.
|
||||||
|
"""
|
||||||
|
graph_data = await store_db.get_agent(store_listing_version_id)
|
||||||
|
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
|
||||||
|
|
||||||
|
# Sending graph as a stream (similar to marketplace v1)
|
||||||
|
with tempfile.NamedTemporaryFile(
|
||||||
|
mode="w", suffix=".json", delete=False
|
||||||
|
) as tmp_file:
|
||||||
|
tmp_file.write(backend.util.json.dumps(graph_data))
|
||||||
|
tmp_file.flush()
|
||||||
|
|
||||||
|
return fastapi.responses.FileResponse(
|
||||||
|
tmp_file.name, filename=file_name, media_type="application/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
##############################################
|
##############################################
|
||||||
|
|||||||
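The routes diff above trades FastAPI's declarative `Query(ge=1)` validators for a hand-rolled guard repeated inline in `get_agents`, `unified_search`, `get_creators`, and `get_submissions`. As a rough, dependency-free sketch of that repeated pattern (the `validate_pagination` helper and the stand-in exception class are hypothetical, not part of the diff):

```python
# Stand-in for fastapi.HTTPException so this sketch needs no dependencies.
class HTTPException(Exception):
    def __init__(self, status_code: int, detail: str):
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


def validate_pagination(page: int, page_size: int) -> None:
    """Reject non-positive pagination values with a 422, as the new endpoints do inline."""
    if page < 1:
        raise HTTPException(status_code=422, detail="Page must be greater than 0")
    if page_size < 1:
        raise HTTPException(status_code=422, detail="Page size must be greater than 0")
```

With FastAPI available, the same constraint is expressed declaratively as `page: typing.Annotated[int, fastapi.Query(ge=1)] = 1`, which the diff keeps only in `get_my_agents`.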
@@ -8,8 +8,6 @@ import pytest
 import pytest_mock
 from pytest_snapshot.plugin import Snapshot
 
-from backend.api.features.store.db import StoreAgentsSortOptions
-
 from . import model as store_model
 from . import routes as store_routes
 
@@ -198,7 +196,7 @@ def test_get_agents_sorted(
     mock_db_call.assert_called_once_with(
         featured=False,
         creators=None,
-        sorted_by=StoreAgentsSortOptions.RUNS,
+        sorted_by="runs",
         search_query=None,
         category=None,
         page=1,
@@ -382,11 +380,9 @@ def test_get_agent_details(
         runs=100,
         rating=4.5,
         versions=["1.0.0", "1.1.0"],
-        graph_versions=["1", "2"],
-        graph_id="test-graph-id",
+        agentGraphVersions=["1", "2"],
+        agentGraphId="test-graph-id",
         last_updated=FIXED_NOW,
-        active_version_id="test-version-id",
-        has_approved_version=True,
     )
     mock_db_call = mocker.patch("backend.api.features.store.db.get_store_agent_details")
     mock_db_call.return_value = mocked_value
@@ -439,17 +435,15 @@ def test_get_creators_pagination(
 ) -> None:
     mocked_value = store_model.CreatorsResponse(
         creators=[
-            store_model.CreatorDetails(
+            store_model.Creator(
                 name=f"Creator {i}",
                 username=f"creator{i}",
-                avatar_url=f"avatar{i}.jpg",
                 description=f"Creator {i} description",
-                links=[f"user{i}.link.com"],
-                is_featured=False,
+                avatar_url=f"avatar{i}.jpg",
num_agents=1,
|
num_agents=1,
|
||||||
agent_runs=100,
|
|
||||||
agent_rating=4.5,
|
agent_rating=4.5,
|
||||||
top_categories=["cat1", "cat2", "cat3"],
|
agent_runs=100,
|
||||||
|
is_featured=False,
|
||||||
)
|
)
|
||||||
for i in range(5)
|
for i in range(5)
|
||||||
],
|
],
|
||||||
@@ -502,19 +496,19 @@ def test_get_creator_details(
|
|||||||
mocked_value = store_model.CreatorDetails(
|
mocked_value = store_model.CreatorDetails(
|
||||||
name="Test User",
|
name="Test User",
|
||||||
username="creator1",
|
username="creator1",
|
||||||
avatar_url="avatar.jpg",
|
|
||||||
description="Test creator description",
|
description="Test creator description",
|
||||||
links=["link1.com", "link2.com"],
|
links=["link1.com", "link2.com"],
|
||||||
is_featured=True,
|
avatar_url="avatar.jpg",
|
||||||
num_agents=5,
|
|
||||||
agent_runs=1000,
|
|
||||||
agent_rating=4.8,
|
agent_rating=4.8,
|
||||||
|
agent_runs=1000,
|
||||||
top_categories=["category1", "category2"],
|
top_categories=["category1", "category2"],
|
||||||
)
|
)
|
||||||
mock_db_call = mocker.patch("backend.api.features.store.db.get_store_creator")
|
mock_db_call = mocker.patch(
|
||||||
|
"backend.api.features.store.db.get_store_creator_details"
|
||||||
|
)
|
||||||
mock_db_call.return_value = mocked_value
|
mock_db_call.return_value = mocked_value
|
||||||
|
|
||||||
response = client.get("/creators/creator1")
|
response = client.get("/creator/creator1")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
data = store_model.CreatorDetails.model_validate(response.json())
|
data = store_model.CreatorDetails.model_validate(response.json())
|
||||||
@@ -534,26 +528,19 @@ def test_get_submissions_success(
|
|||||||
submissions=[
|
submissions=[
|
||||||
store_model.StoreSubmission(
|
store_model.StoreSubmission(
|
||||||
listing_id="test-listing-id",
|
listing_id="test-listing-id",
|
||||||
user_id="test-user-id",
|
|
||||||
slug="test-agent",
|
|
||||||
listing_version_id="test-version-id",
|
|
||||||
listing_version=1,
|
|
||||||
graph_id="test-agent-id",
|
|
||||||
graph_version=1,
|
|
||||||
name="Test Agent",
|
name="Test Agent",
|
||||||
sub_heading="Test agent subheading",
|
|
||||||
description="Test agent description",
|
description="Test agent description",
|
||||||
instructions="Click the button!",
|
|
||||||
categories=["test-category"],
|
|
||||||
image_urls=["test.jpg"],
|
image_urls=["test.jpg"],
|
||||||
video_url="test.mp4",
|
date_submitted=FIXED_NOW,
|
||||||
agent_output_demo_url="demo_video.mp4",
|
|
||||||
submitted_at=FIXED_NOW,
|
|
||||||
changes_summary="Initial Submission",
|
|
||||||
status=prisma.enums.SubmissionStatus.APPROVED,
|
status=prisma.enums.SubmissionStatus.APPROVED,
|
||||||
run_count=50,
|
runs=50,
|
||||||
review_count=5,
|
rating=4.2,
|
||||||
review_avg_rating=4.2,
|
agent_id="test-agent-id",
|
||||||
|
agent_version=1,
|
||||||
|
sub_heading="Test agent subheading",
|
||||||
|
slug="test-agent",
|
||||||
|
video_url="test.mp4",
|
||||||
|
categories=["test-category"],
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
pagination=store_model.Pagination(
|
pagination=store_model.Pagination(
|
||||||
|
```diff
@@ -11,7 +11,6 @@ import pytest
 from backend.util.models import Pagination
 
 from . import cache as store_cache
-from .db import StoreAgentsSortOptions
 from .model import StoreAgent, StoreAgentsResponse
 
 
@@ -216,7 +215,7 @@ class TestCacheDeletion:
         await store_cache._get_cached_store_agents(
             featured=True,
             creator="testuser",
-            sorted_by=StoreAgentsSortOptions.RATING,
+            sorted_by="rating",
             search_query="AI assistant",
             category="productivity",
             page=2,
@@ -228,7 +227,7 @@ class TestCacheDeletion:
         deleted = store_cache._get_cached_store_agents.cache_delete(
             featured=True,
             creator="testuser",
-            sorted_by=StoreAgentsSortOptions.RATING,
+            sorted_by="rating",
             search_query="AI assistant",
             category="productivity",
             page=2,
@@ -240,7 +239,7 @@ class TestCacheDeletion:
         deleted = store_cache._get_cached_store_agents.cache_delete(
             featured=True,
             creator="testuser",
-            sorted_by=StoreAgentsSortOptions.RATING,
+            sorted_by="rating",
             search_query="AI assistant",
             category="productivity",
             page=2,
```
```diff
@@ -55,7 +55,6 @@ from backend.data.credit import (
     set_auto_top_up,
 )
 from backend.data.graph import GraphSettings
-from backend.data.invited_user import get_or_activate_user
 from backend.data.model import CredentialsMetaInput, UserOnboarding
 from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
 from backend.data.onboarding import (
@@ -71,6 +70,7 @@ from backend.data.onboarding import (
     update_user_onboarding,
 )
 from backend.data.user import (
+    get_or_create_user,
     get_user_by_id,
     get_user_notification_preference,
     update_user_email,
@@ -126,9 +126,6 @@ v1_router = APIRouter()
 ########################################################
 
 
-_tally_background_tasks: set[asyncio.Task] = set()
-
-
 @v1_router.post(
     "/auth/user",
     summary="Get or create user",
@@ -136,23 +133,7 @@ _tally_background_tasks: set[asyncio.Task] = set()
     dependencies=[Security(requires_user)],
 )
 async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
-    user = await get_or_activate_user(user_data)
-
-    # Fire-and-forget: backfill Tally understanding when invite pre-seeding did
-    # not produce a stored result before first activation.
-    age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
-    if age_seconds < 30:
-        try:
-            from backend.data.tally import populate_understanding_from_tally
-
-            task = asyncio.create_task(
-                populate_understanding_from_tally(user.id, user.email)
-            )
-            _tally_background_tasks.add(task)
-            task.add_done_callback(_tally_background_tasks.discard)
-        except Exception:
-            logger.debug("Failed to start Tally population task", exc_info=True)
-
+    user = await get_or_create_user(user_data)
     return user.model_dump()
 
 
@@ -163,8 +144,7 @@ async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
     dependencies=[Security(requires_user)],
 )
 async def update_user_email_route(
-    user_id: Annotated[str, Security(get_user_id)],
-    email: str = Body(...),
+    user_id: Annotated[str, Security(get_user_id)], email: str = Body(...)
 ) -> dict[str, str]:
     await update_user_email(user_id, email)
 
@@ -178,16 +158,10 @@ async def update_user_email_route(
     dependencies=[Security(requires_user)],
 )
 async def get_user_timezone_route(
-    user_id: Annotated[str, Security(get_user_id)],
+    user_data: dict = Security(get_jwt_payload),
 ) -> TimezoneResponse:
     """Get user timezone setting."""
-    try:
-        user = await get_user_by_id(user_id)
-    except ValueError:
-        raise HTTPException(
-            status_code=HTTP_404_NOT_FOUND,
-            detail="User not found. Please complete activation via /auth/user first.",
-        )
+    user = await get_or_create_user(user_data)
     return TimezoneResponse(timezone=user.timezone)
 
 
@@ -198,8 +172,7 @@ async def get_user_timezone_route(
     dependencies=[Security(requires_user)],
 )
 async def update_user_timezone_route(
-    user_id: Annotated[str, Security(get_user_id)],
-    request: UpdateTimezoneRequest,
+    user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
 ) -> TimezoneResponse:
     """Update user timezone. The timezone should be a valid IANA timezone identifier."""
     user = await update_user_timezone(user_id, str(request.timezone))
@@ -455,6 +428,7 @@ async def execute_graph_block(
 async def upload_file(
     user_id: Annotated[str, Security(get_user_id)],
     file: UploadFile = File(...),
+    provider: str = "gcs",
     expiration_hours: int = 24,
 ) -> UploadFileResponse:
     """
@@ -517,6 +491,7 @@ async def upload_file(
     storage_path = await cloud_storage.store_file(
         content=content,
         filename=file_name,
+        provider=provider,
         expiration_hours=expiration_hours,
         user_id=user_id,
     )
```
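The block removed from `get_or_create_user_route` uses the standard asyncio fire-and-forget pattern: hold a strong reference to the task in a module-level set and discard it on completion, because the event loop itself keeps only weak references to tasks. A minimal standalone sketch of that pattern (function and variable names here are illustrative):

```python
import asyncio

_background_tasks: set[asyncio.Task] = set()


def fire_and_forget(coro) -> asyncio.Task:
    """Schedule `coro` without awaiting it, keeping a strong reference.

    A pending task with no other reference can be garbage-collected
    mid-flight; the done callback releases the reference once it finishes.
    """
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


async def main() -> int:
    async def work() -> None:
        await asyncio.sleep(0)

    fire_and_forget(work())
    assert len(_background_tasks) == 1  # reference held while pending
    await asyncio.gather(*_background_tasks)
    await asyncio.sleep(0)  # let the done callbacks run
    return len(_background_tasks)


print(asyncio.run(main()))  # → 0
```

Wrapping the `create_task` call in try/except, as the removed code did, keeps a failure to schedule the side task from breaking the request itself.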
```diff
@@ -1,5 +1,5 @@
 import json
-from datetime import datetime, timezone
+from datetime import datetime
 from io import BytesIO
 from unittest.mock import AsyncMock, Mock, patch
 
@@ -43,7 +43,6 @@ def test_get_or_create_user_route(
 ) -> None:
     """Test get or create user endpoint"""
    mock_user = Mock()
-    mock_user.created_at = datetime.now(timezone.utc)
     mock_user.model_dump.return_value = {
         "id": test_user_id,
         "email": "test@example.com",
@@ -51,7 +50,7 @@ def test_get_or_create_user_route(
     }
 
     mocker.patch(
-        "backend.api.features.v1.get_or_activate_user",
+        "backend.api.features.v1.get_or_create_user",
         return_value=mock_user,
     )
 
@@ -515,6 +514,7 @@ async def test_upload_file_success(test_user_id: str):
     result = await upload_file(
         file=upload_file_mock,
         user_id=test_user_id,
+        provider="gcs",
         expiration_hours=24,
     )
 
@@ -532,6 +532,7 @@ async def test_upload_file_success(test_user_id: str):
     mock_handler.store_file.assert_called_once_with(
         content=file_content,
         filename="test.txt",
+        provider="gcs",
         expiration_hours=24,
         user_id=test_user_id,
     )
```
```diff
@@ -3,29 +3,15 @@ Workspace API routes for managing user file storage.
 """
 
 import logging
-import os
 import re
 from typing import Annotated
 from urllib.parse import quote
 
 import fastapi
 from autogpt_libs.auth.dependencies import get_user_id, requires_user
-from fastapi import Query, UploadFile
 from fastapi.responses import Response
-from pydantic import BaseModel
 
-from backend.data.workspace import (
-    WorkspaceFile,
-    count_workspace_files,
-    get_or_create_workspace,
-    get_workspace,
-    get_workspace_file,
-    get_workspace_total_size,
-    soft_delete_workspace_file,
-)
-from backend.util.settings import Config
-from backend.util.virus_scanner import scan_content_safe
-from backend.util.workspace import WorkspaceManager
+from backend.data.workspace import WorkspaceFile, get_workspace, get_workspace_file
 from backend.util.workspace_storage import get_workspace_storage
 
 
@@ -112,25 +98,6 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
         raise
 
 
-class UploadFileResponse(BaseModel):
-    file_id: str
-    name: str
-    path: str
-    mime_type: str
-    size_bytes: int
-
-
-class DeleteFileResponse(BaseModel):
-    deleted: bool
-
-
-class StorageUsageResponse(BaseModel):
-    used_bytes: int
-    limit_bytes: int
-    used_percent: float
-    file_count: int
-
-
 @router.get(
     "/files/{file_id}/download",
     summary="Download file by ID",
@@ -153,148 +120,3 @@ async def download_file(
         raise fastapi.HTTPException(status_code=404, detail="File not found")
 
     return await _create_file_download_response(file)
-
-
-@router.delete(
-    "/files/{file_id}",
-    summary="Delete a workspace file",
-)
-async def delete_workspace_file(
-    user_id: Annotated[str, fastapi.Security(get_user_id)],
-    file_id: str,
-) -> DeleteFileResponse:
-    """
-    Soft-delete a workspace file and attempt to remove it from storage.
-
-    Used when a user clears a file input in the builder.
-    """
-    workspace = await get_workspace(user_id)
-    if workspace is None:
-        raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
-
-    manager = WorkspaceManager(user_id, workspace.id)
-    deleted = await manager.delete_file(file_id)
-    if not deleted:
-        raise fastapi.HTTPException(status_code=404, detail="File not found")
-
-    return DeleteFileResponse(deleted=True)
-
-
-@router.post(
-    "/files/upload",
-    summary="Upload file to workspace",
-)
-async def upload_file(
-    user_id: Annotated[str, fastapi.Security(get_user_id)],
-    file: UploadFile,
-    session_id: str | None = Query(default=None),
-) -> UploadFileResponse:
-    """
-    Upload a file to the user's workspace.
-
-    Files are stored in session-scoped paths when session_id is provided,
-    so the agent's session-scoped tools can discover them automatically.
-    """
-    config = Config()
-
-    # Sanitize filename — strip any directory components
-    filename = os.path.basename(file.filename or "upload") or "upload"
-
-    # Read file content with early abort on size limit
-    max_file_bytes = config.max_file_size_mb * 1024 * 1024
-    chunks: list[bytes] = []
-    total_size = 0
-    while chunk := await file.read(64 * 1024):  # 64KB chunks
-        total_size += len(chunk)
-        if total_size > max_file_bytes:
-            raise fastapi.HTTPException(
-                status_code=413,
-                detail=f"File exceeds maximum size of {config.max_file_size_mb} MB",
-            )
-        chunks.append(chunk)
-    content = b"".join(chunks)
-
-    # Get or create workspace
-    workspace = await get_or_create_workspace(user_id)
-
-    # Pre-write storage cap check (soft check — final enforcement is post-write)
-    storage_limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
-    current_usage = await get_workspace_total_size(workspace.id)
-    if storage_limit_bytes and current_usage + len(content) > storage_limit_bytes:
-        used_percent = (current_usage / storage_limit_bytes) * 100
-        raise fastapi.HTTPException(
-            status_code=413,
-            detail={
-                "message": "Storage limit exceeded",
-                "used_bytes": current_usage,
-                "limit_bytes": storage_limit_bytes,
-                "used_percent": round(used_percent, 1),
-            },
-        )
-
-    # Warn at 80% usage
-    if (
-        storage_limit_bytes
-        and (usage_ratio := (current_usage + len(content)) / storage_limit_bytes) >= 0.8
-    ):
-        logger.warning(
-            f"User {user_id} workspace storage at {usage_ratio * 100:.1f}% "
-            f"({current_usage + len(content)} / {storage_limit_bytes} bytes)"
-        )
-
-    # Virus scan
-    await scan_content_safe(content, filename=filename)
-
-    # Write file via WorkspaceManager
-    manager = WorkspaceManager(user_id, workspace.id, session_id)
-    try:
-        workspace_file = await manager.write_file(content, filename)
-    except ValueError as e:
-        raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
-
-    # Post-write storage check — eliminates TOCTOU race on the quota.
-    # If a concurrent upload pushed us over the limit, undo this write.
-    new_total = await get_workspace_total_size(workspace.id)
-    if storage_limit_bytes and new_total > storage_limit_bytes:
-        await soft_delete_workspace_file(workspace_file.id, workspace.id)
-        raise fastapi.HTTPException(
-            status_code=413,
-            detail={
-                "message": "Storage limit exceeded (concurrent upload)",
-                "used_bytes": new_total,
-                "limit_bytes": storage_limit_bytes,
-            },
-        )
-
-    return UploadFileResponse(
-        file_id=workspace_file.id,
-        name=workspace_file.name,
-        path=workspace_file.path,
-        mime_type=workspace_file.mime_type,
-        size_bytes=workspace_file.size_bytes,
-    )
-
-
-@router.get(
-    "/storage/usage",
-    summary="Get workspace storage usage",
-)
-async def get_storage_usage(
-    user_id: Annotated[str, fastapi.Security(get_user_id)],
-) -> StorageUsageResponse:
-    """
-    Get storage usage information for the user's workspace.
-    """
-    config = Config()
-    workspace = await get_or_create_workspace(user_id)
-
-    used_bytes = await get_workspace_total_size(workspace.id)
-    file_count = await count_workspace_files(workspace.id)
-    limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
-
-    return StorageUsageResponse(
-        used_bytes=used_bytes,
-        limit_bytes=limit_bytes,
-        used_percent=round((used_bytes / limit_bytes) * 100, 1) if limit_bytes else 0,
-        file_count=file_count,
-    )
```
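The removed `upload_file` route reads the request body in 64KB chunks and aborts as soon as the per-file limit is crossed, rather than buffering the whole upload before checking its size. That loop can be sketched against any file-like object (the error type here is illustrative, not from the diff):

```python
import io

CHUNK_SIZE = 64 * 1024  # 64KB chunks, matching the removed route


class FileTooLargeError(Exception):
    pass


def read_limited(stream: io.BufferedIOBase, max_bytes: int) -> bytes:
    """Read a stream in chunks, aborting as soon as the limit is crossed.

    The body is never fully buffered before the size check, so an
    oversized upload fails early instead of exhausting memory first.
    """
    chunks: list[bytes] = []
    total = 0
    while chunk := stream.read(CHUNK_SIZE):
        total += len(chunk)
        if total > max_bytes:
            raise FileTooLargeError(f"exceeds {max_bytes} bytes")
        chunks.append(chunk)
    return b"".join(chunks)


print(len(read_limited(io.BytesIO(b"x" * 1000), max_bytes=2048)))  # → 1000
```

In the route this exception maps to an HTTP 413; the second, post-write quota check then closes the TOCTOU window that the pre-write check alone would leave open.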
```diff
@@ -1,359 +0,0 @@
-"""Tests for workspace file upload and download routes."""
-
-import io
-from datetime import datetime, timezone
-
-import fastapi
-import fastapi.testclient
-import pytest
-import pytest_mock
-
-from backend.api.features.workspace import routes as workspace_routes
-from backend.data.workspace import WorkspaceFile
-
-app = fastapi.FastAPI()
-app.include_router(workspace_routes.router)
-
-
-@app.exception_handler(ValueError)
-async def _value_error_handler(
-    request: fastapi.Request, exc: ValueError
-) -> fastapi.responses.JSONResponse:
-    """Mirror the production ValueError → 400 mapping from rest_api.py."""
-    return fastapi.responses.JSONResponse(status_code=400, content={"detail": str(exc)})
-
-
-client = fastapi.testclient.TestClient(app)
-
-TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
-
-MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})()
-
-_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)
-
-MOCK_FILE = WorkspaceFile(
-    id="file-aaa-bbb",
-    workspace_id="ws-1",
-    created_at=_NOW,
-    updated_at=_NOW,
-    name="hello.txt",
-    path="/session/hello.txt",
-    mime_type="text/plain",
-    size_bytes=13,
-    storage_path="local://hello.txt",
-)
-
-
-@pytest.fixture(autouse=True)
-def setup_app_auth(mock_jwt_user):
-    from autogpt_libs.auth.jwt_utils import get_jwt_payload
-
-    app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
-    yield
-    app.dependency_overrides.clear()
-
-
-def _upload(
-    filename: str = "hello.txt",
-    content: bytes = b"Hello, world!",
-    content_type: str = "text/plain",
-):
-    """Helper to POST a file upload."""
-    return client.post(
-        "/files/upload?session_id=sess-1",
-        files={"file": (filename, io.BytesIO(content), content_type)},
-    )
-
-
-# ---- Happy path ----
-
-
-def test_upload_happy_path(mocker: pytest_mock.MockFixture):
-    mocker.patch(
-        "backend.api.features.workspace.routes.get_or_create_workspace",
-        return_value=MOCK_WORKSPACE,
-    )
-    mocker.patch(
-        "backend.api.features.workspace.routes.get_workspace_total_size",
-        return_value=0,
-    )
-    mocker.patch(
-        "backend.api.features.workspace.routes.scan_content_safe",
-        return_value=None,
-    )
-    mock_manager = mocker.MagicMock()
-    mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
-    mocker.patch(
-        "backend.api.features.workspace.routes.WorkspaceManager",
-        return_value=mock_manager,
-    )
-
-    response = _upload()
-    assert response.status_code == 200
-    data = response.json()
-    assert data["file_id"] == "file-aaa-bbb"
-    assert data["name"] == "hello.txt"
-    assert data["size_bytes"] == 13
-
-
-# ---- Per-file size limit ----
-
-
-def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
-    """Files larger than max_file_size_mb should be rejected with 413."""
-    cfg = mocker.patch("backend.api.features.workspace.routes.Config")
-    cfg.return_value.max_file_size_mb = 0  # 0 MB → any content is too big
-    cfg.return_value.max_workspace_storage_mb = 500
-
-    response = _upload(content=b"x" * 1024)
-    assert response.status_code == 413
-
-
-# ---- Storage quota exceeded ----
-
-
-def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
-    mocker.patch(
-        "backend.api.features.workspace.routes.get_or_create_workspace",
-        return_value=MOCK_WORKSPACE,
-    )
-    # Current usage already at limit
-    mocker.patch(
-        "backend.api.features.workspace.routes.get_workspace_total_size",
-        return_value=500 * 1024 * 1024,
-    )
-
-    response = _upload()
-    assert response.status_code == 413
-    assert "Storage limit exceeded" in response.text
-
-
-# ---- Post-write quota race (B2) ----
-
-
-def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
-    """If a concurrent upload tips the total over the limit after write,
-    the file should be soft-deleted and 413 returned."""
-    mocker.patch(
-        "backend.api.features.workspace.routes.get_or_create_workspace",
-        return_value=MOCK_WORKSPACE,
-    )
-    # Pre-write check passes (under limit), but post-write check fails
-    mocker.patch(
-        "backend.api.features.workspace.routes.get_workspace_total_size",
-        side_effect=[0, 600 * 1024 * 1024],  # first call OK, second over limit
-    )
-    mocker.patch(
-        "backend.api.features.workspace.routes.scan_content_safe",
-        return_value=None,
-    )
-    mock_manager = mocker.MagicMock()
-    mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
-    mocker.patch(
-        "backend.api.features.workspace.routes.WorkspaceManager",
-        return_value=mock_manager,
-    )
-    mock_delete = mocker.patch(
-        "backend.api.features.workspace.routes.soft_delete_workspace_file",
-        return_value=None,
-    )
-
-    response = _upload()
-    assert response.status_code == 413
-    mock_delete.assert_called_once_with("file-aaa-bbb", "ws-1")
-
-
-# ---- Any extension accepted (no allowlist) ----
-
-
-def test_upload_any_extension(mocker: pytest_mock.MockFixture):
-    """Any file extension should be accepted — ClamAV is the security layer."""
-    mocker.patch(
-        "backend.api.features.workspace.routes.get_or_create_workspace",
-        return_value=MOCK_WORKSPACE,
-    )
-    mocker.patch(
-        "backend.api.features.workspace.routes.get_workspace_total_size",
-        return_value=0,
-    )
-    mocker.patch(
-        "backend.api.features.workspace.routes.scan_content_safe",
-        return_value=None,
-    )
-    mock_manager = mocker.MagicMock()
-    mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
-    mocker.patch(
-        "backend.api.features.workspace.routes.WorkspaceManager",
-        return_value=mock_manager,
-    )
-
-    response = _upload(filename="data.xyz", content=b"arbitrary")
-    assert response.status_code == 200
-
-
-# ---- Virus scan rejection ----
-
-
-def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
-    """Files flagged by ClamAV should be rejected and never written to storage."""
-    from backend.api.features.store.exceptions import VirusDetectedError
-
-    mocker.patch(
-        "backend.api.features.workspace.routes.get_or_create_workspace",
-        return_value=MOCK_WORKSPACE,
-    )
-    mocker.patch(
-        "backend.api.features.workspace.routes.get_workspace_total_size",
-        return_value=0,
-    )
-    mocker.patch(
-        "backend.api.features.workspace.routes.scan_content_safe",
-        side_effect=VirusDetectedError("Eicar-Test-Signature"),
-    )
-    mock_manager = mocker.MagicMock()
-    mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
-    mocker.patch(
-        "backend.api.features.workspace.routes.WorkspaceManager",
-        return_value=mock_manager,
-    )
-
-    response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
-    assert response.status_code == 400
-    assert "Virus detected" in response.text
-    mock_manager.write_file.assert_not_called()
-
-
-# ---- No file extension ----
-
-
-def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
-    """Files without an extension should be accepted and stored as-is."""
-    mocker.patch(
-        "backend.api.features.workspace.routes.get_or_create_workspace",
-        return_value=MOCK_WORKSPACE,
-    )
-    mocker.patch(
-        "backend.api.features.workspace.routes.get_workspace_total_size",
-        return_value=0,
-    )
-    mocker.patch(
-        "backend.api.features.workspace.routes.scan_content_safe",
-        return_value=None,
-    )
-    mock_manager = mocker.MagicMock()
-    mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
-    mocker.patch(
-        "backend.api.features.workspace.routes.WorkspaceManager",
-        return_value=mock_manager,
-    )
-
-    response = _upload(
-        filename="Makefile",
-        content=b"all:\n\techo hello",
-        content_type="application/octet-stream",
-    )
-    assert response.status_code == 200
-    mock_manager.write_file.assert_called_once()
-    assert mock_manager.write_file.call_args[0][1] == "Makefile"
-
-
-# ---- Filename sanitization (SF5) ----
-
-
-def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
-    """Path-traversal filenames should be reduced to their basename."""
-    mocker.patch(
-        "backend.api.features.workspace.routes.get_or_create_workspace",
```
||||||
"backend.api.features.workspace.routes.get_workspace_total_size",
|
|
||||||
return_value=0,
|
|
||||||
)
|
|
||||||
mocker.patch(
|
|
||||||
"backend.api.features.workspace.routes.scan_content_safe",
|
|
||||||
return_value=None,
|
|
||||||
)
|
|
||||||
mock_manager = mocker.MagicMock()
|
|
||||||
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
|
|
||||||
mocker.patch(
|
|
||||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
|
||||||
return_value=mock_manager,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Filename with traversal
|
|
||||||
_upload(filename="../../etc/passwd.txt")
|
|
||||||
|
|
||||||
# write_file should have been called with just the basename
|
|
||||||
mock_manager.write_file.assert_called_once()
|
|
||||||
call_args = mock_manager.write_file.call_args
|
|
||||||
assert call_args[0][1] == "passwd.txt"
|
|
||||||
|
|
||||||
|
|
||||||
# ---- Download ----
|
|
||||||
|
|
||||||
|
|
||||||
def test_download_file_not_found(mocker: pytest_mock.MockFixture):
|
|
||||||
mocker.patch(
|
|
||||||
"backend.api.features.workspace.routes.get_workspace",
|
|
||||||
return_value=MOCK_WORKSPACE,
|
|
||||||
)
|
|
||||||
mocker.patch(
|
|
||||||
"backend.api.features.workspace.routes.get_workspace_file",
|
|
||||||
return_value=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
response = client.get("/files/some-file-id/download")
|
|
||||||
assert response.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
# ---- Delete ----
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete_file_success(mocker: pytest_mock.MockFixture):
|
|
||||||
"""Deleting an existing file should return {"deleted": true}."""
|
|
||||||
mocker.patch(
|
|
||||||
"backend.api.features.workspace.routes.get_workspace",
|
|
||||||
return_value=MOCK_WORKSPACE,
|
|
||||||
)
|
|
||||||
mock_manager = mocker.MagicMock()
|
|
||||||
mock_manager.delete_file = mocker.AsyncMock(return_value=True)
|
|
||||||
mocker.patch(
|
|
||||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
|
||||||
return_value=mock_manager,
|
|
||||||
)
|
|
||||||
|
|
||||||
response = client.delete("/files/file-aaa-bbb")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json() == {"deleted": True}
|
|
||||||
mock_manager.delete_file.assert_called_once_with("file-aaa-bbb")
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
|
|
||||||
"""Deleting a non-existent file should return 404."""
|
|
||||||
mocker.patch(
|
|
||||||
"backend.api.features.workspace.routes.get_workspace",
|
|
||||||
return_value=MOCK_WORKSPACE,
|
|
||||||
)
|
|
||||||
mock_manager = mocker.MagicMock()
|
|
||||||
mock_manager.delete_file = mocker.AsyncMock(return_value=False)
|
|
||||||
mocker.patch(
|
|
||||||
"backend.api.features.workspace.routes.WorkspaceManager",
|
|
||||||
return_value=mock_manager,
|
|
||||||
)
|
|
||||||
|
|
||||||
response = client.delete("/files/nonexistent-id")
|
|
||||||
assert response.status_code == 404
|
|
||||||
assert "File not found" in response.text
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
|
|
||||||
"""Deleting when user has no workspace should return 404."""
|
|
||||||
mocker.patch(
|
|
||||||
"backend.api.features.workspace.routes.get_workspace",
|
|
||||||
return_value=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
response = client.delete("/files/file-aaa-bbb")
|
|
||||||
assert response.status_code == 404
|
|
||||||
assert "Workspace not found" in response.text
|
|
||||||
@@ -94,8 +94,3 @@ class NotificationPayload(pydantic.BaseModel):
 
 class OnboardingNotificationPayload(NotificationPayload):
     step: OnboardingStep | None
-
-
-class CopilotCompletionPayload(NotificationPayload):
-    session_id: str
-    status: Literal["completed", "failed"]
@@ -19,7 +19,6 @@ from prisma.errors import PrismaError
 import backend.api.features.admin.credit_admin_routes
 import backend.api.features.admin.execution_analytics_routes
 import backend.api.features.admin.store_admin_routes
-import backend.api.features.admin.user_admin_routes
 import backend.api.features.builder
 import backend.api.features.builder.routes
 import backend.api.features.chat.routes as chat_routes
@@ -42,11 +41,11 @@ import backend.data.user
 import backend.integrations.webhooks.utils
 import backend.util.service
 import backend.util.settings
-from backend.api.features.library.exceptions import (
-    FolderAlreadyExistsError,
-    FolderValidationError,
-)
 from backend.blocks.llm import DEFAULT_LLM_MODEL
+from backend.copilot.completion_consumer import (
+    start_completion_consumer,
+    stop_completion_consumer,
+)
 from backend.data.model import Credentials
 from backend.integrations.providers import ProviderName
 from backend.monitoring.instrumentation import instrument_fastapi
@@ -56,7 +55,6 @@ from backend.util.exceptions import (
     MissingConfigError,
     NotAuthorizedError,
     NotFoundError,
-    PreconditionFailed,
 )
 from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
 from backend.util.service import UnhealthyServiceError
@@ -125,9 +123,21 @@ async def lifespan_context(app: fastapi.FastAPI):
     await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
     await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
 
+    # Start chat completion consumer for Redis Streams notifications
+    try:
+        await start_completion_consumer()
+    except Exception as e:
+        logger.warning(f"Could not start chat completion consumer: {e}")
+
     with launch_darkly_context():
         yield
 
+    # Stop chat completion consumer
+    try:
+        await stop_completion_consumer()
+    except Exception as e:
+        logger.warning(f"Error stopping chat completion consumer: {e}")
+
     try:
         await shutdown_cloud_storage_handler()
     except Exception as e:
@@ -267,17 +277,12 @@ async def validation_error_handler(
 
 
 app.add_exception_handler(PrismaError, handle_internal_http_error(500))
-app.add_exception_handler(
-    FolderAlreadyExistsError, handle_internal_http_error(409, False)
-)
-app.add_exception_handler(FolderValidationError, handle_internal_http_error(400, False))
 app.add_exception_handler(NotFoundError, handle_internal_http_error(404, False))
 app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403, False))
 app.add_exception_handler(RequestValidationError, validation_error_handler)
 app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
 app.add_exception_handler(MissingConfigError, handle_internal_http_error(503))
 app.add_exception_handler(ValueError, handle_internal_http_error(400))
-app.add_exception_handler(PreconditionFailed, handle_internal_http_error(428))
 app.add_exception_handler(Exception, handle_internal_http_error(500))
 
 app.include_router(backend.api.features.v1.v1_router, tags=["v1"], prefix="/api")
@@ -312,11 +317,6 @@ app.include_router(
     tags=["v2", "admin"],
     prefix="/api/executions",
 )
-app.include_router(
-    backend.api.features.admin.user_admin_routes.router,
-    tags=["v2", "admin"],
-    prefix="/api/users",
-)
 app.include_router(
     backend.api.features.executions.review.routes.router,
     tags=["v2", "executions", "review"],
@@ -24,7 +24,7 @@ def run_processes(*processes: "AppProcess", **kwargs):
         # Run the last process in the foreground.
         processes[-1].start(background=False, **kwargs)
     finally:
-        for process in reversed(processes):
+        for process in processes:
             try:
                 process.stop()
             except Exception as e:
@@ -418,8 +418,6 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
 
 
 class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
-    _optimized_description: ClassVar[str | None] = None
-
     def __init__(
         self,
         id: str = "",
@@ -472,8 +470,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
         self.block_type = block_type
         self.webhook_config = webhook_config
         self.is_sensitive_action = is_sensitive_action
-        # Read from ClassVar set by initialize_blocks()
-        self.optimized_description: str | None = type(self)._optimized_description
         self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
 
         if self.webhook_config:
@@ -624,7 +620,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
         graph_id: str,
         graph_version: int,
         execution_context: "ExecutionContext",
-        is_graph_execution: bool = True,
         **kwargs,
     ) -> tuple[bool, BlockInput]:
         """
@@ -653,7 +648,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
             graph_version=graph_version,
             block_name=self.name,
             editable=True,
-            is_graph_execution=is_graph_execution,
         )
 
         if decision is None:
@@ -126,7 +126,7 @@ class PrintToConsoleBlock(Block):
             output_schema=PrintToConsoleBlock.Output,
             test_input={"text": "Hello, World!"},
             is_sensitive_action=True,
-            disabled=True,
+            disabled=True,  # Disabled per Nick Tindle's request (OPEN-3000)
             test_output=[
                 ("output", "Hello, World!"),
                 ("status", "printed"),
@@ -142,7 +142,7 @@ class BaseE2BExecutorMixin:
         start_timestamp = ts_result.stdout.strip() if ts_result.stdout else None
 
         # Execute the code
-        execution = await sandbox.run_code(  # type: ignore[attr-defined]
+        execution = await sandbox.run_code(
             code,
             language=language.value,
             on_error=lambda e: sandbox.kill(),  # Kill the sandbox on error
@@ -96,7 +96,6 @@ class SendEmailBlock(Block):
             test_credentials=TEST_CREDENTIALS,
             test_output=[("status", "Email sent successfully")],
             test_mock={"send_email": lambda *args, **kwargs: "Email sent successfully"},
-            is_sensitive_action=True,
         )
 
     @staticmethod
@@ -1,3 +0,0 @@
-def github_repo_path(repo_url: str) -> str:
-    """Extract 'owner/repo' from a GitHub repository URL."""
-    return repo_url.replace("https://github.com/", "")
@@ -1,408 +0,0 @@
-import asyncio
-from enum import StrEnum
-from urllib.parse import quote
-
-from typing_extensions import TypedDict
-
-from backend.blocks._base import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
-from backend.data.execution import ExecutionContext
-from backend.data.model import SchemaField
-from backend.util.file import parse_data_uri, resolve_media_content
-from backend.util.type import MediaFileType
-
-from ._api import get_api
-from ._auth import (
-    TEST_CREDENTIALS,
-    TEST_CREDENTIALS_INPUT,
-    GithubCredentials,
-    GithubCredentialsField,
-    GithubCredentialsInput,
-)
-from ._utils import github_repo_path
-
-
-class GithubListCommitsBlock(Block):
-    class Input(BlockSchemaInput):
-        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
-        repo_url: str = SchemaField(
-            description="URL of the GitHub repository",
-            placeholder="https://github.com/owner/repo",
-        )
-        branch: str = SchemaField(
-            description="Branch name to list commits from",
-            default="main",
-        )
-        per_page: int = SchemaField(
-            description="Number of commits to return (max 100)",
-            default=30,
-            ge=1,
-            le=100,
-        )
-        page: int = SchemaField(
-            description="Page number for pagination",
-            default=1,
-            ge=1,
-        )
-
-    class Output(BlockSchemaOutput):
-        class CommitItem(TypedDict):
-            sha: str
-            message: str
-            author: str
-            date: str
-            url: str
-
-        commit: CommitItem = SchemaField(
-            title="Commit", description="A commit with its details"
-        )
-        commits: list[CommitItem] = SchemaField(
-            description="List of commits with their details"
-        )
-        error: str = SchemaField(description="Error message if listing commits failed")
-
-    def __init__(self):
-        super().__init__(
-            id="8b13f579-d8b6-4dc2-a140-f770428805de",
-            description="This block lists commits on a branch in a GitHub repository.",
-            categories={BlockCategory.DEVELOPER_TOOLS},
-            input_schema=GithubListCommitsBlock.Input,
-            output_schema=GithubListCommitsBlock.Output,
-            test_input={
-                "repo_url": "https://github.com/owner/repo",
-                "branch": "main",
-                "per_page": 30,
-                "page": 1,
-                "credentials": TEST_CREDENTIALS_INPUT,
-            },
-            test_credentials=TEST_CREDENTIALS,
-            test_output=[
-                (
-                    "commits",
-                    [
-                        {
-                            "sha": "abc123",
-                            "message": "Initial commit",
-                            "author": "octocat",
-                            "date": "2024-01-01T00:00:00Z",
-                            "url": "https://github.com/owner/repo/commit/abc123",
-                        }
-                    ],
-                ),
-                (
-                    "commit",
-                    {
-                        "sha": "abc123",
-                        "message": "Initial commit",
-                        "author": "octocat",
-                        "date": "2024-01-01T00:00:00Z",
-                        "url": "https://github.com/owner/repo/commit/abc123",
-                    },
-                ),
-            ],
-            test_mock={
-                "list_commits": lambda *args, **kwargs: [
-                    {
-                        "sha": "abc123",
-                        "message": "Initial commit",
-                        "author": "octocat",
-                        "date": "2024-01-01T00:00:00Z",
-                        "url": "https://github.com/owner/repo/commit/abc123",
-                    }
-                ]
-            },
-        )
-
-    @staticmethod
-    async def list_commits(
-        credentials: GithubCredentials,
-        repo_url: str,
-        branch: str,
-        per_page: int,
-        page: int,
-    ) -> list[Output.CommitItem]:
-        api = get_api(credentials)
-        commits_url = repo_url + "/commits"
-        params = {"sha": branch, "per_page": str(per_page), "page": str(page)}
-        response = await api.get(commits_url, params=params)
-        data = response.json()
-        repo_path = github_repo_path(repo_url)
-        return [
-            GithubListCommitsBlock.Output.CommitItem(
-                sha=c["sha"],
-                message=c["commit"]["message"],
-                author=(c["commit"].get("author") or {}).get("name", "Unknown"),
-                date=(c["commit"].get("author") or {}).get("date", ""),
-                url=f"https://github.com/{repo_path}/commit/{c['sha']}",
-            )
-            for c in data
-        ]
-
-    async def run(
-        self,
-        input_data: Input,
-        *,
-        credentials: GithubCredentials,
-        **kwargs,
-    ) -> BlockOutput:
-        try:
-            commits = await self.list_commits(
-                credentials,
-                input_data.repo_url,
-                input_data.branch,
-                input_data.per_page,
-                input_data.page,
-            )
-            yield "commits", commits
-            for commit in commits:
-                yield "commit", commit
-        except Exception as e:
-            yield "error", str(e)
-
-
-class FileOperation(StrEnum):
-    """File operations for GithubMultiFileCommitBlock.
-
-    UPSERT creates or overwrites a file (the Git Trees API does not distinguish
-    between creation and update — the blob is placed at the given path regardless
-    of whether a file already exists there).
-
-    DELETE removes a file from the tree.
-    """
-
-    UPSERT = "upsert"
-    DELETE = "delete"
-
-
-class FileOperationInput(TypedDict):
-    path: str
-    # MediaFileType is a str NewType — no runtime breakage for existing callers.
-    content: MediaFileType
-    operation: FileOperation
-
-
-class GithubMultiFileCommitBlock(Block):
-    class Input(BlockSchemaInput):
-        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
-        repo_url: str = SchemaField(
-            description="URL of the GitHub repository",
-            placeholder="https://github.com/owner/repo",
-        )
-        branch: str = SchemaField(
-            description="Branch to commit to",
-            placeholder="feature-branch",
-        )
-        commit_message: str = SchemaField(
-            description="Commit message",
-            placeholder="Add new feature",
-        )
-        files: list[FileOperationInput] = SchemaField(
-            description=(
-                "List of file operations. Each item has: "
-                "'path' (file path), 'content' (file content, ignored for delete), "
-                "'operation' (upsert/delete)"
-            ),
-        )
-
-    class Output(BlockSchemaOutput):
-        sha: str = SchemaField(description="SHA of the new commit")
-        url: str = SchemaField(description="URL of the new commit")
-        error: str = SchemaField(description="Error message if the commit failed")
-
-    def __init__(self):
-        super().__init__(
-            id="389eee51-a95e-4230-9bed-92167a327802",
-            description=(
-                "This block creates a single commit with multiple file "
-                "upsert/delete operations using the Git Trees API."
-            ),
-            categories={BlockCategory.DEVELOPER_TOOLS},
-            input_schema=GithubMultiFileCommitBlock.Input,
-            output_schema=GithubMultiFileCommitBlock.Output,
-            test_input={
-                "repo_url": "https://github.com/owner/repo",
-                "branch": "feature",
-                "commit_message": "Add files",
-                "files": [
-                    {
-                        "path": "src/new.py",
-                        "content": "print('hello')",
-                        "operation": "upsert",
-                    },
-                    {
-                        "path": "src/old.py",
-                        "content": "",
-                        "operation": "delete",
-                    },
-                ],
-                "credentials": TEST_CREDENTIALS_INPUT,
-            },
-            test_credentials=TEST_CREDENTIALS,
-            test_output=[
-                ("sha", "newcommitsha"),
-                ("url", "https://github.com/owner/repo/commit/newcommitsha"),
-            ],
-            test_mock={
-                "multi_file_commit": lambda *args, **kwargs: (
-                    "newcommitsha",
-                    "https://github.com/owner/repo/commit/newcommitsha",
-                )
-            },
-        )
-
-    @staticmethod
-    async def multi_file_commit(
-        credentials: GithubCredentials,
-        repo_url: str,
-        branch: str,
-        commit_message: str,
-        files: list[FileOperationInput],
-    ) -> tuple[str, str]:
-        api = get_api(credentials)
-        safe_branch = quote(branch, safe="")
-
-        # 1. Get the latest commit SHA for the branch
-        ref_url = repo_url + f"/git/refs/heads/{safe_branch}"
-        response = await api.get(ref_url)
-        ref_data = response.json()
-        latest_commit_sha = ref_data["object"]["sha"]
-
-        # 2. Get the tree SHA of the latest commit
-        commit_url = repo_url + f"/git/commits/{latest_commit_sha}"
-        response = await api.get(commit_url)
-        commit_data = response.json()
-        base_tree_sha = commit_data["tree"]["sha"]
-
-        # 3. Build tree entries for each file operation (blobs created concurrently)
-        async def _create_blob(content: str, encoding: str = "utf-8") -> str:
-            blob_url = repo_url + "/git/blobs"
-            blob_response = await api.post(
-                blob_url,
-                json={"content": content, "encoding": encoding},
-            )
-            return blob_response.json()["sha"]
-
-        tree_entries: list[dict] = []
-        upsert_files = []
-        for file_op in files:
-            path = file_op["path"]
-            operation = FileOperation(file_op.get("operation", "upsert"))
-
-            if operation == FileOperation.DELETE:
-                tree_entries.append(
-                    {
-                        "path": path,
-                        "mode": "100644",
-                        "type": "blob",
-                        "sha": None,  # null SHA = delete
-                    }
-                )
-            else:
-                upsert_files.append((path, file_op.get("content", "")))
-
-        # Create all blobs concurrently. Data URIs (from store_media_file)
-        # are sent as base64 blobs to preserve binary content.
-        if upsert_files:
-
-            async def _make_blob(content: str) -> str:
-                parsed = parse_data_uri(content)
-                if parsed is not None:
-                    _, b64_payload = parsed
-                    return await _create_blob(b64_payload, encoding="base64")
-                return await _create_blob(content)
-
-            blob_shas = await asyncio.gather(
-                *[_make_blob(content) for _, content in upsert_files]
-            )
-            for (path, _), blob_sha in zip(upsert_files, blob_shas):
-                tree_entries.append(
-                    {
-                        "path": path,
-                        "mode": "100644",
-                        "type": "blob",
-                        "sha": blob_sha,
-                    }
-                )
-
-        # 4. Create a new tree
-        tree_url = repo_url + "/git/trees"
-        tree_response = await api.post(
-            tree_url,
-            json={"base_tree": base_tree_sha, "tree": tree_entries},
-        )
-        new_tree_sha = tree_response.json()["sha"]
-
-        # 5. Create a new commit
-        new_commit_url = repo_url + "/git/commits"
-        commit_response = await api.post(
-            new_commit_url,
-            json={
-                "message": commit_message,
-                "tree": new_tree_sha,
-                "parents": [latest_commit_sha],
-            },
-        )
-        new_commit_sha = commit_response.json()["sha"]
-
-        # 6. Update the branch reference
-        try:
-            await api.patch(
-                ref_url,
-                json={"sha": new_commit_sha},
-            )
-        except Exception as e:
-            raise RuntimeError(
-                f"Commit {new_commit_sha} was created but failed to update "
-                f"ref heads/{branch}: {e}. "
-                f"You can recover by manually updating the branch to {new_commit_sha}."
-            ) from e
-
-        repo_path = github_repo_path(repo_url)
-        commit_web_url = f"https://github.com/{repo_path}/commit/{new_commit_sha}"
-        return new_commit_sha, commit_web_url
-
-    async def run(
-        self,
-        input_data: Input,
-        *,
-        credentials: GithubCredentials,
-        execution_context: ExecutionContext,
-        **kwargs,
-    ) -> BlockOutput:
-        try:
-            # Resolve media references (workspace://, data:, URLs) to data
-            # URIs so _make_blob can send binary content correctly.
-            resolved_files: list[FileOperationInput] = []
-            for file_op in input_data.files:
-                content = file_op.get("content", "")
-                operation = FileOperation(file_op.get("operation", "upsert"))
-                if operation != FileOperation.DELETE:
-                    content = await resolve_media_content(
-                        MediaFileType(content),
-                        execution_context,
-                        return_format="for_external_api",
-                    )
-                resolved_files.append(
-                    FileOperationInput(
-                        path=file_op["path"],
-                        content=MediaFileType(content),
-                        operation=operation,
-                    )
-                )
-
-            sha, url = await self.multi_file_commit(
-                credentials,
-                input_data.repo_url,
-                input_data.branch,
-                input_data.commit_message,
-                resolved_files,
-            )
-            yield "sha", sha
-            yield "url", url
-        except Exception as e:
-            yield "error", str(e)
@@ -1,5 +1,4 @@
 import re
-from typing import Literal
 
 from typing_extensions import TypedDict
 
@@ -21,8 +20,6 @@ from ._auth import (
     GithubCredentialsInput,
 )
 
-MergeMethod = Literal["merge", "squash", "rebase"]
-
 
 class GithubListPullRequestsBlock(Block):
     class Input(BlockSchemaInput):
@@ -561,109 +558,12 @@ class GithubListPRReviewersBlock(Block):
             yield "reviewer", reviewer
 
 
-class GithubMergePullRequestBlock(Block):
-    class Input(BlockSchemaInput):
-        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
-        pr_url: str = SchemaField(
-            description="URL of the GitHub pull request",
-            placeholder="https://github.com/owner/repo/pull/1",
-        )
-        merge_method: MergeMethod = SchemaField(
-            description="Merge method to use: merge, squash, or rebase",
-            default="merge",
-        )
-        commit_title: str = SchemaField(
-            description="Title for the merge commit (optional, used for merge and squash)",
-            default="",
-        )
-        commit_message: str = SchemaField(
-            description="Message for the merge commit (optional, used for merge and squash)",
-            default="",
-        )
-
-    class Output(BlockSchemaOutput):
-        sha: str = SchemaField(description="SHA of the merge commit")
-        merged: bool = SchemaField(description="Whether the PR was merged")
-        message: str = SchemaField(description="Merge status message")
-        error: str = SchemaField(description="Error message if the merge failed")
-
-    def __init__(self):
-        super().__init__(
-            id="77456c22-33d8-4fd4-9eef-50b46a35bb48",
-            description="This block merges a pull request using merge, squash, or rebase.",
-            categories={BlockCategory.DEVELOPER_TOOLS},
-            input_schema=GithubMergePullRequestBlock.Input,
-            output_schema=GithubMergePullRequestBlock.Output,
-            test_input={
-                "pr_url": "https://github.com/owner/repo/pull/1",
-                "merge_method": "squash",
-                "commit_title": "",
-                "commit_message": "",
-                "credentials": TEST_CREDENTIALS_INPUT,
-            },
-            test_credentials=TEST_CREDENTIALS,
-            test_output=[
-                ("sha", "abc123"),
-                ("merged", True),
-                ("message", "Pull Request successfully merged"),
-            ],
-            test_mock={
-                "merge_pr": lambda *args, **kwargs: (
-                    "abc123",
-                    True,
-                    "Pull Request successfully merged",
-                )
-            },
-            is_sensitive_action=True,
-        )
-
-    @staticmethod
-    async def merge_pr(
-        credentials: GithubCredentials,
-        pr_url: str,
-        merge_method: MergeMethod,
|
|
||||||
commit_title: str,
|
|
||||||
commit_message: str,
|
|
||||||
) -> tuple[str, bool, str]:
|
|
||||||
api = get_api(credentials)
|
|
||||||
merge_url = prepare_pr_api_url(pr_url=pr_url, path="merge")
|
|
||||||
data: dict[str, str] = {"merge_method": merge_method}
|
|
||||||
if commit_title:
|
|
||||||
data["commit_title"] = commit_title
|
|
||||||
if commit_message:
|
|
||||||
data["commit_message"] = commit_message
|
|
||||||
response = await api.put(merge_url, json=data)
|
|
||||||
result = response.json()
|
|
||||||
return result["sha"], result["merged"], result["message"]
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
credentials: GithubCredentials,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
try:
|
|
||||||
sha, merged, message = await self.merge_pr(
|
|
||||||
credentials,
|
|
||||||
input_data.pr_url,
|
|
||||||
input_data.merge_method,
|
|
||||||
input_data.commit_title,
|
|
||||||
input_data.commit_message,
|
|
||||||
)
|
|
||||||
yield "sha", sha
|
|
||||||
yield "merged", merged
|
|
||||||
yield "message", message
|
|
||||||
except Exception as e:
|
|
||||||
yield "error", str(e)
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_pr_api_url(pr_url: str, path: str) -> str:
|
def prepare_pr_api_url(pr_url: str, path: str) -> str:
|
||||||
# Pattern to capture the base repository URL and the pull request number
|
# Pattern to capture the base repository URL and the pull request number
|
||||||
pattern = r"^(?:(https?)://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
|
pattern = r"^(?:https?://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
|
||||||
match = re.match(pattern, pr_url)
|
match = re.match(pattern, pr_url)
|
||||||
if not match:
|
if not match:
|
||||||
return pr_url
|
return pr_url
|
||||||
|
|
||||||
scheme, base_url, pr_number = match.groups()
|
base_url, pr_number = match.groups()
|
||||||
return f"{scheme or 'https'}://{base_url}/pulls/{pr_number}/{path}"
|
return f"{base_url}/pulls/{pr_number}/{path}"
|
||||||
|
|||||||
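The last hunk above changes how `prepare_pr_api_url` handles URL schemes: one side captures the scheme and re-adds it (defaulting to `https`), the other strips it entirely. A standalone sketch of the two behaviors; the function names here are illustrative, not from the codebase:

```python
import re

# Both regex patterns are taken verbatim from the diff above.
PATTERN_WITH_SCHEME = r"^(?:(https?)://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
PATTERN_NO_SCHEME = r"^(?:https?://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"


def with_scheme(pr_url: str, path: str) -> str:
    # Keeps the matched scheme (or defaults to https) when rebuilding the URL.
    match = re.match(PATTERN_WITH_SCHEME, pr_url)
    if not match:
        return pr_url
    scheme, base_url, pr_number = match.groups()
    return f"{scheme or 'https'}://{base_url}/pulls/{pr_number}/{path}"


def without_scheme(pr_url: str, path: str) -> str:
    # Discards the scheme; the result carries no "https://" prefix.
    match = re.match(PATTERN_NO_SCHEME, pr_url)
    if not match:
        return pr_url
    base_url, pr_number = match.groups()
    return f"{base_url}/pulls/{pr_number}/{path}"
```

For `https://github.com/owner/repo/pull/1` with path `merge`, the first variant returns `https://github.com/owner/repo/pulls/1/merge` while the second returns `github.com/owner/repo/pulls/1/merge`.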
```diff
@@ -1,3 +1,5 @@
+import base64
+
 from typing_extensions import TypedDict
 
 from backend.blocks._base import (
@@ -17,7 +19,6 @@ from ._auth import (
     GithubCredentialsField,
     GithubCredentialsInput,
 )
-from ._utils import github_repo_path
 
 
 class GithubListTagsBlock(Block):
@@ -88,7 +89,7 @@ class GithubListTagsBlock(Block):
         tags_url = repo_url + "/tags"
         response = await api.get(tags_url)
         data = response.json()
-        repo_path = github_repo_path(repo_url)
+        repo_path = repo_url.replace("https://github.com/", "")
         tags: list[GithubListTagsBlock.Output.TagItem] = [
             {
                 "name": tag["name"],
@@ -114,6 +115,101 @@ class GithubListTagsBlock(Block):
             yield "tag", tag
 
 
+class GithubListBranchesBlock(Block):
+    class Input(BlockSchemaInput):
+        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
+        repo_url: str = SchemaField(
+            description="URL of the GitHub repository",
+            placeholder="https://github.com/owner/repo",
+        )
+
+    class Output(BlockSchemaOutput):
+        class BranchItem(TypedDict):
+            name: str
+            url: str
+
+        branch: BranchItem = SchemaField(
+            title="Branch",
+            description="Branches with their name and file tree browser URL",
+        )
+        branches: list[BranchItem] = SchemaField(
+            description="List of branches with their name and file tree browser URL"
+        )
+
+    def __init__(self):
+        super().__init__(
+            id="74243e49-2bec-4916-8bf4-db43d44aead5",
+            description="This block lists all branches for a specified GitHub repository.",
+            categories={BlockCategory.DEVELOPER_TOOLS},
+            input_schema=GithubListBranchesBlock.Input,
+            output_schema=GithubListBranchesBlock.Output,
+            test_input={
+                "repo_url": "https://github.com/owner/repo",
+                "credentials": TEST_CREDENTIALS_INPUT,
+            },
+            test_credentials=TEST_CREDENTIALS,
+            test_output=[
+                (
+                    "branches",
+                    [
+                        {
+                            "name": "main",
+                            "url": "https://github.com/owner/repo/tree/main",
+                        }
+                    ],
+                ),
+                (
+                    "branch",
+                    {
+                        "name": "main",
+                        "url": "https://github.com/owner/repo/tree/main",
+                    },
+                ),
+            ],
+            test_mock={
+                "list_branches": lambda *args, **kwargs: [
+                    {
+                        "name": "main",
+                        "url": "https://github.com/owner/repo/tree/main",
+                    }
+                ]
+            },
+        )
+
+    @staticmethod
+    async def list_branches(
+        credentials: GithubCredentials, repo_url: str
+    ) -> list[Output.BranchItem]:
+        api = get_api(credentials)
+        branches_url = repo_url + "/branches"
+        response = await api.get(branches_url)
+        data = response.json()
+        repo_path = repo_url.replace("https://github.com/", "")
+        branches: list[GithubListBranchesBlock.Output.BranchItem] = [
+            {
+                "name": branch["name"],
+                "url": f"https://github.com/{repo_path}/tree/{branch['name']}",
+            }
+            for branch in data
+        ]
+        return branches
+
+    async def run(
+        self,
+        input_data: Input,
+        *,
+        credentials: GithubCredentials,
+        **kwargs,
+    ) -> BlockOutput:
+        branches = await self.list_branches(
+            credentials,
+            input_data.repo_url,
+        )
+        yield "branches", branches
+        for branch in branches:
+            yield "branch", branch
+
+
 class GithubListDiscussionsBlock(Block):
     class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
@@ -187,7 +283,7 @@ class GithubListDiscussionsBlock(Block):
     ) -> list[Output.DiscussionItem]:
         api = get_api(credentials)
         # GitHub GraphQL API endpoint is different; we'll use api.post with custom URL
-        repo_path = github_repo_path(repo_url)
+        repo_path = repo_url.replace("https://github.com/", "")
         owner, repo = repo_path.split("/")
         query = """
         query($owner: String!, $repo: String!, $num: Int!) {
@@ -320,6 +416,564 @@ class GithubListReleasesBlock(Block):
             yield "release", release
 
 
+class GithubReadFileBlock(Block):
+    class Input(BlockSchemaInput):
+        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
+        repo_url: str = SchemaField(
+            description="URL of the GitHub repository",
+            placeholder="https://github.com/owner/repo",
+        )
+        file_path: str = SchemaField(
+            description="Path to the file in the repository",
+            placeholder="path/to/file",
+        )
+        branch: str = SchemaField(
+            description="Branch to read from",
+            placeholder="branch_name",
+            default="master",
+        )
+
+    class Output(BlockSchemaOutput):
+        text_content: str = SchemaField(
+            description="Content of the file (decoded as UTF-8 text)"
+        )
+        raw_content: str = SchemaField(
+            description="Raw base64-encoded content of the file"
+        )
+        size: int = SchemaField(description="The size of the file (in bytes)")
+
+    def __init__(self):
+        super().__init__(
+            id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
+            description="This block reads the content of a specified file from a GitHub repository.",
+            categories={BlockCategory.DEVELOPER_TOOLS},
+            input_schema=GithubReadFileBlock.Input,
+            output_schema=GithubReadFileBlock.Output,
+            test_input={
+                "repo_url": "https://github.com/owner/repo",
+                "file_path": "path/to/file",
+                "branch": "master",
+                "credentials": TEST_CREDENTIALS_INPUT,
+            },
+            test_credentials=TEST_CREDENTIALS,
+            test_output=[
+                ("raw_content", "RmlsZSBjb250ZW50"),
+                ("text_content", "File content"),
+                ("size", 13),
+            ],
+            test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
+        )
+
+    @staticmethod
+    async def read_file(
+        credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
+    ) -> tuple[str, int]:
+        api = get_api(credentials)
+        content_url = repo_url + f"/contents/{file_path}?ref={branch}"
+        response = await api.get(content_url)
+        data = response.json()
+
+        if isinstance(data, list):
+            # Multiple entries of different types exist at this path
+            if not (file := next((f for f in data if f["type"] == "file"), None)):
+                raise TypeError("Not a file")
+            data = file
+
+        if data["type"] != "file":
+            raise TypeError("Not a file")
+
+        return data["content"], data["size"]
+
+    async def run(
+        self,
+        input_data: Input,
+        *,
+        credentials: GithubCredentials,
+        **kwargs,
+    ) -> BlockOutput:
+        content, size = await self.read_file(
+            credentials,
+            input_data.repo_url,
+            input_data.file_path,
+            input_data.branch,
+        )
+        yield "raw_content", content
+        yield "text_content", base64.b64decode(content).decode("utf-8")
+        yield "size", size
+
+
+class GithubReadFolderBlock(Block):
+    class Input(BlockSchemaInput):
+        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
+        repo_url: str = SchemaField(
+            description="URL of the GitHub repository",
+            placeholder="https://github.com/owner/repo",
+        )
+        folder_path: str = SchemaField(
+            description="Path to the folder in the repository",
+            placeholder="path/to/folder",
+        )
+        branch: str = SchemaField(
+            description="Branch name to read from (defaults to master)",
+            placeholder="branch_name",
+            default="master",
+        )
+
+    class Output(BlockSchemaOutput):
+        class DirEntry(TypedDict):
+            name: str
+            path: str
+
+        class FileEntry(TypedDict):
+            name: str
+            path: str
+            size: int
+
+        file: FileEntry = SchemaField(description="Files in the folder")
+        dir: DirEntry = SchemaField(description="Directories in the folder")
+        error: str = SchemaField(
+            description="Error message if reading the folder failed"
+        )
+
+    def __init__(self):
+        super().__init__(
+            id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
+            description="This block reads the content of a specified folder from a GitHub repository.",
+            categories={BlockCategory.DEVELOPER_TOOLS},
+            input_schema=GithubReadFolderBlock.Input,
+            output_schema=GithubReadFolderBlock.Output,
+            test_input={
+                "repo_url": "https://github.com/owner/repo",
+                "folder_path": "path/to/folder",
+                "branch": "master",
+                "credentials": TEST_CREDENTIALS_INPUT,
+            },
+            test_credentials=TEST_CREDENTIALS,
+            test_output=[
+                (
+                    "file",
+                    {
+                        "name": "file1.txt",
+                        "path": "path/to/folder/file1.txt",
+                        "size": 1337,
+                    },
+                ),
+                ("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
+            ],
+            test_mock={
+                "read_folder": lambda *args, **kwargs: (
+                    [
+                        {
+                            "name": "file1.txt",
+                            "path": "path/to/folder/file1.txt",
+                            "size": 1337,
+                        }
+                    ],
+                    [{"name": "dir2", "path": "path/to/folder/dir2"}],
+                )
+            },
+        )
+
+    @staticmethod
+    async def read_folder(
+        credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
+    ) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
+        api = get_api(credentials)
+        contents_url = repo_url + f"/contents/{folder_path}?ref={branch}"
+        response = await api.get(contents_url)
+        data = response.json()
+
+        if not isinstance(data, list):
+            raise TypeError("Not a folder")
+
+        files: list[GithubReadFolderBlock.Output.FileEntry] = [
+            GithubReadFolderBlock.Output.FileEntry(
+                name=entry["name"],
+                path=entry["path"],
+                size=entry["size"],
+            )
+            for entry in data
+            if entry["type"] == "file"
+        ]
+
+        dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
+            GithubReadFolderBlock.Output.DirEntry(
+                name=entry["name"],
+                path=entry["path"],
+            )
+            for entry in data
+            if entry["type"] == "dir"
+        ]
+
+        return files, dirs
+
+    async def run(
+        self,
+        input_data: Input,
+        *,
+        credentials: GithubCredentials,
+        **kwargs,
+    ) -> BlockOutput:
+        files, dirs = await self.read_folder(
+            credentials,
+            input_data.repo_url,
+            input_data.folder_path.lstrip("/"),
+            input_data.branch,
+        )
+        for file in files:
+            yield "file", file
+        for dir in dirs:
+            yield "dir", dir
+
+
+class GithubMakeBranchBlock(Block):
+    class Input(BlockSchemaInput):
+        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
+        repo_url: str = SchemaField(
+            description="URL of the GitHub repository",
+            placeholder="https://github.com/owner/repo",
+        )
+        new_branch: str = SchemaField(
+            description="Name of the new branch",
+            placeholder="new_branch_name",
+        )
+        source_branch: str = SchemaField(
+            description="Name of the source branch",
+            placeholder="source_branch_name",
+        )
+
+    class Output(BlockSchemaOutput):
+        status: str = SchemaField(description="Status of the branch creation operation")
+        error: str = SchemaField(
+            description="Error message if the branch creation failed"
+        )
+
+    def __init__(self):
+        super().__init__(
+            id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
+            description="This block creates a new branch from a specified source branch.",
+            categories={BlockCategory.DEVELOPER_TOOLS},
+            input_schema=GithubMakeBranchBlock.Input,
+            output_schema=GithubMakeBranchBlock.Output,
+            test_input={
+                "repo_url": "https://github.com/owner/repo",
+                "new_branch": "new_branch_name",
+                "source_branch": "source_branch_name",
+                "credentials": TEST_CREDENTIALS_INPUT,
+            },
+            test_credentials=TEST_CREDENTIALS,
+            test_output=[("status", "Branch created successfully")],
+            test_mock={
+                "create_branch": lambda *args, **kwargs: "Branch created successfully"
+            },
+        )
+
+    @staticmethod
+    async def create_branch(
+        credentials: GithubCredentials,
+        repo_url: str,
+        new_branch: str,
+        source_branch: str,
+    ) -> str:
+        api = get_api(credentials)
+        ref_url = repo_url + f"/git/refs/heads/{source_branch}"
+        response = await api.get(ref_url)
+        data = response.json()
+        sha = data["object"]["sha"]
+
+        # Create the new branch
+        new_ref_url = repo_url + "/git/refs"
+        data = {
+            "ref": f"refs/heads/{new_branch}",
+            "sha": sha,
+        }
+        response = await api.post(new_ref_url, json=data)
+        return "Branch created successfully"
+
+    async def run(
+        self,
+        input_data: Input,
+        *,
+        credentials: GithubCredentials,
+        **kwargs,
+    ) -> BlockOutput:
+        status = await self.create_branch(
+            credentials,
+            input_data.repo_url,
+            input_data.new_branch,
+            input_data.source_branch,
+        )
+        yield "status", status
+
+
+class GithubDeleteBranchBlock(Block):
+    class Input(BlockSchemaInput):
+        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
+        repo_url: str = SchemaField(
+            description="URL of the GitHub repository",
+            placeholder="https://github.com/owner/repo",
+        )
+        branch: str = SchemaField(
+            description="Name of the branch to delete",
+            placeholder="branch_name",
+        )
+
+    class Output(BlockSchemaOutput):
+        status: str = SchemaField(description="Status of the branch deletion operation")
+        error: str = SchemaField(
+            description="Error message if the branch deletion failed"
+        )
+
+    def __init__(self):
+        super().__init__(
+            id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
+            description="This block deletes a specified branch.",
+            categories={BlockCategory.DEVELOPER_TOOLS},
+            input_schema=GithubDeleteBranchBlock.Input,
+            output_schema=GithubDeleteBranchBlock.Output,
+            test_input={
+                "repo_url": "https://github.com/owner/repo",
+                "branch": "branch_name",
+                "credentials": TEST_CREDENTIALS_INPUT,
+            },
+            test_credentials=TEST_CREDENTIALS,
+            test_output=[("status", "Branch deleted successfully")],
+            test_mock={
+                "delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
+            },
+        )
+
+    @staticmethod
+    async def delete_branch(
+        credentials: GithubCredentials, repo_url: str, branch: str
+    ) -> str:
+        api = get_api(credentials)
+        ref_url = repo_url + f"/git/refs/heads/{branch}"
+        await api.delete(ref_url)
+        return "Branch deleted successfully"
+
+    async def run(
+        self,
+        input_data: Input,
+        *,
+        credentials: GithubCredentials,
+        **kwargs,
+    ) -> BlockOutput:
+        status = await self.delete_branch(
+            credentials,
+            input_data.repo_url,
+            input_data.branch,
+        )
+        yield "status", status
+
+
+class GithubCreateFileBlock(Block):
+    class Input(BlockSchemaInput):
+        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
+        repo_url: str = SchemaField(
+            description="URL of the GitHub repository",
+            placeholder="https://github.com/owner/repo",
+        )
+        file_path: str = SchemaField(
+            description="Path where the file should be created",
+            placeholder="path/to/file.txt",
+        )
+        content: str = SchemaField(
+            description="Content to write to the file",
+            placeholder="File content here",
+        )
+        branch: str = SchemaField(
+            description="Branch where the file should be created",
+            default="main",
+        )
+        commit_message: str = SchemaField(
+            description="Message for the commit",
+            default="Create new file",
+        )
+
+    class Output(BlockSchemaOutput):
+        url: str = SchemaField(description="URL of the created file")
+        sha: str = SchemaField(description="SHA of the commit")
+        error: str = SchemaField(
+            description="Error message if the file creation failed"
+        )
+
+    def __init__(self):
+        super().__init__(
+            id="8fd132ac-b917-428a-8159-d62893e8a3fe",
+            description="This block creates a new file in a GitHub repository.",
+            categories={BlockCategory.DEVELOPER_TOOLS},
+            input_schema=GithubCreateFileBlock.Input,
+            output_schema=GithubCreateFileBlock.Output,
+            test_input={
+                "repo_url": "https://github.com/owner/repo",
+                "file_path": "test/file.txt",
+                "content": "Test content",
+                "branch": "main",
+                "commit_message": "Create test file",
+                "credentials": TEST_CREDENTIALS_INPUT,
+            },
+            test_credentials=TEST_CREDENTIALS,
+            test_output=[
+                ("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
+                ("sha", "abc123"),
+            ],
+            test_mock={
+                "create_file": lambda *args, **kwargs: (
+                    "https://github.com/owner/repo/blob/main/test/file.txt",
+                    "abc123",
+                )
+            },
+        )
+
+    @staticmethod
+    async def create_file(
+        credentials: GithubCredentials,
+        repo_url: str,
+        file_path: str,
+        content: str,
+        branch: str,
+        commit_message: str,
+    ) -> tuple[str, str]:
+        api = get_api(credentials)
+        contents_url = repo_url + f"/contents/{file_path}"
+        content_base64 = base64.b64encode(content.encode()).decode()
+        data = {
+            "message": commit_message,
+            "content": content_base64,
+            "branch": branch,
+        }
+        response = await api.put(contents_url, json=data)
+        data = response.json()
+        return data["content"]["html_url"], data["commit"]["sha"]
+
+    async def run(
+        self,
+        input_data: Input,
+        *,
+        credentials: GithubCredentials,
+        **kwargs,
+    ) -> BlockOutput:
+        try:
+            url, sha = await self.create_file(
+                credentials,
+                input_data.repo_url,
+                input_data.file_path,
+                input_data.content,
+                input_data.branch,
+                input_data.commit_message,
+            )
+            yield "url", url
+            yield "sha", sha
+        except Exception as e:
+            yield "error", str(e)
+
+
+class GithubUpdateFileBlock(Block):
+    class Input(BlockSchemaInput):
+        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
+        repo_url: str = SchemaField(
+            description="URL of the GitHub repository",
+            placeholder="https://github.com/owner/repo",
+        )
+        file_path: str = SchemaField(
+            description="Path to the file to update",
+            placeholder="path/to/file.txt",
+        )
+        content: str = SchemaField(
+            description="New content for the file",
+            placeholder="Updated content here",
+        )
+        branch: str = SchemaField(
+            description="Branch containing the file",
+            default="main",
+        )
+        commit_message: str = SchemaField(
+            description="Message for the commit",
+            default="Update file",
+        )
+
+    class Output(BlockSchemaOutput):
+        url: str = SchemaField(description="URL of the updated file")
+        sha: str = SchemaField(description="SHA of the commit")
+
+    def __init__(self):
+        super().__init__(
+            id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
+            description="This block updates an existing file in a GitHub repository.",
+            categories={BlockCategory.DEVELOPER_TOOLS},
+            input_schema=GithubUpdateFileBlock.Input,
+            output_schema=GithubUpdateFileBlock.Output,
+            test_input={
+                "repo_url": "https://github.com/owner/repo",
+                "file_path": "test/file.txt",
+                "content": "Updated content",
+                "branch": "main",
+                "commit_message": "Update test file",
+                "credentials": TEST_CREDENTIALS_INPUT,
+            },
+            test_credentials=TEST_CREDENTIALS,
+            test_output=[
+                ("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
+                ("sha", "def456"),
+            ],
+            test_mock={
+                "update_file": lambda *args, **kwargs: (
+                    "https://github.com/owner/repo/blob/main/test/file.txt",
+                    "def456",
+                )
+            },
+        )
+
+    @staticmethod
+    async def update_file(
+        credentials: GithubCredentials,
+        repo_url: str,
+        file_path: str,
+        content: str,
+        branch: str,
+        commit_message: str,
+    ) -> tuple[str, str]:
+        api = get_api(credentials)
+        contents_url = repo_url + f"/contents/{file_path}"
+        params = {"ref": branch}
+        response = await api.get(contents_url, params=params)
+        data = response.json()
+
+        # Convert new content to base64
+        content_base64 = base64.b64encode(content.encode()).decode()
+        data = {
+            "message": commit_message,
+            "content": content_base64,
+            "sha": data["sha"],
+            "branch": branch,
+        }
+        response = await api.put(contents_url, json=data)
+        data = response.json()
+        return data["content"]["html_url"], data["commit"]["sha"]
+
+    async def run(
+        self,
+        input_data: Input,
+        *,
+        credentials: GithubCredentials,
+        **kwargs,
+    ) -> BlockOutput:
+        try:
+            url, sha = await self.update_file(
+                credentials,
+                input_data.repo_url,
+                input_data.file_path,
+                input_data.content,
+                input_data.branch,
+                input_data.commit_message,
+            )
+            yield "url", url
+            yield "sha", sha
+        except Exception as e:
+            yield "error", str(e)
+
+
 class GithubCreateRepositoryBlock(Block):
     class Input(BlockSchemaInput):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
```
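The create, update, and read file blocks added in the hunk above all exchange file bodies with the GitHub contents API as base64 strings. A minimal sketch of that round trip; the helper names are illustrative, not from the codebase:

```python
import base64


def encode_content(text: str) -> str:
    # Mirrors what the create/update blocks do before PUTting to /contents/{path}
    return base64.b64encode(text.encode()).decode()


def decode_content(raw: str) -> str:
    # Mirrors what GithubReadFileBlock.run does with the API's "content" field
    return base64.b64decode(raw).decode("utf-8")
```

Decoding `"RmlsZSBjb250ZW50"` (the value used in the read-file block's test mock) yields `"File content"`.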
@@ -449,7 +1103,7 @@ class GithubListStargazersBlock(Block):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
id="e96d01ec-b55e-4a99-8ce8-c8776dce850b", # Generated unique UUID
|
id="a4b9c2d1-e5f6-4g7h-8i9j-0k1l2m3n4o5p", # Generated unique UUID
|
||||||
description="This block lists all users who have starred a specified GitHub repository.",
|
description="This block lists all users who have starred a specified GitHub repository.",
|
||||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||||
input_schema=GithubListStargazersBlock.Input,
|
input_schema=GithubListStargazersBlock.Input,
|
||||||
@@ -518,230 +1172,3 @@ class GithubListStargazersBlock(Block):
             yield "stargazers", stargazers
             for stargazer in stargazers:
                 yield "stargazer", stargazer
-
-
-class GithubGetRepositoryInfoBlock(Block):
-    class Input(BlockSchemaInput):
-        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
-        repo_url: str = SchemaField(
-            description="URL of the GitHub repository",
-            placeholder="https://github.com/owner/repo",
-        )
-
-    class Output(BlockSchemaOutput):
-        name: str = SchemaField(description="Repository name")
-        full_name: str = SchemaField(description="Full repository name (owner/repo)")
-        description: str = SchemaField(description="Repository description")
-        default_branch: str = SchemaField(description="Default branch name (e.g. main)")
-        private: bool = SchemaField(description="Whether the repository is private")
-        html_url: str = SchemaField(description="Web URL of the repository")
-        clone_url: str = SchemaField(description="Git clone URL")
-        stars: int = SchemaField(description="Number of stars")
-        forks: int = SchemaField(description="Number of forks")
-        open_issues: int = SchemaField(description="Number of open issues")
-        error: str = SchemaField(
-            description="Error message if fetching repo info failed"
-        )
-
-    def __init__(self):
-        super().__init__(
-            id="59d4f241-968a-4040-95da-348ac5c5ce27",
-            description="This block retrieves metadata about a GitHub repository.",
-            categories={BlockCategory.DEVELOPER_TOOLS},
-            input_schema=GithubGetRepositoryInfoBlock.Input,
-            output_schema=GithubGetRepositoryInfoBlock.Output,
-            test_input={
-                "repo_url": "https://github.com/owner/repo",
-                "credentials": TEST_CREDENTIALS_INPUT,
-            },
-            test_credentials=TEST_CREDENTIALS,
-            test_output=[
-                ("name", "repo"),
-                ("full_name", "owner/repo"),
-                ("description", "A test repo"),
-                ("default_branch", "main"),
-                ("private", False),
-                ("html_url", "https://github.com/owner/repo"),
-                ("clone_url", "https://github.com/owner/repo.git"),
-                ("stars", 42),
-                ("forks", 5),
-                ("open_issues", 3),
-            ],
-            test_mock={
-                "get_repo_info": lambda *args, **kwargs: {
-                    "name": "repo",
-                    "full_name": "owner/repo",
-                    "description": "A test repo",
-                    "default_branch": "main",
-                    "private": False,
-                    "html_url": "https://github.com/owner/repo",
-                    "clone_url": "https://github.com/owner/repo.git",
-                    "stargazers_count": 42,
-                    "forks_count": 5,
-                    "open_issues_count": 3,
-                }
-            },
-        )
-
-    @staticmethod
-    async def get_repo_info(credentials: GithubCredentials, repo_url: str) -> dict:
-        api = get_api(credentials)
-        response = await api.get(repo_url)
-        return response.json()
-
-    async def run(
-        self,
-        input_data: Input,
-        *,
-        credentials: GithubCredentials,
-        **kwargs,
-    ) -> BlockOutput:
-        try:
-            data = await self.get_repo_info(credentials, input_data.repo_url)
-            yield "name", data["name"]
-            yield "full_name", data["full_name"]
-            yield "description", data.get("description", "") or ""
-            yield "default_branch", data["default_branch"]
-            yield "private", data["private"]
-            yield "html_url", data["html_url"]
-            yield "clone_url", data["clone_url"]
-            yield "stars", data["stargazers_count"]
-            yield "forks", data["forks_count"]
-            yield "open_issues", data["open_issues_count"]
-        except Exception as e:
-            yield "error", str(e)
-
-
-class GithubForkRepositoryBlock(Block):
-    class Input(BlockSchemaInput):
-        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
-        repo_url: str = SchemaField(
-            description="URL of the GitHub repository to fork",
-            placeholder="https://github.com/owner/repo",
-        )
-        organization: str = SchemaField(
-            description="Organization to fork into (leave empty to fork to your account)",
-            default="",
-        )
-
-    class Output(BlockSchemaOutput):
-        url: str = SchemaField(description="URL of the forked repository")
-        clone_url: str = SchemaField(description="Git clone URL of the fork")
-        full_name: str = SchemaField(description="Full name of the fork (owner/repo)")
-        error: str = SchemaField(description="Error message if the fork failed")
-
-    def __init__(self):
-        super().__init__(
-            id="a439f2f4-835f-4dae-ba7b-0205ffa70be6",
-            description="This block forks a GitHub repository to your account or an organization.",
-            categories={BlockCategory.DEVELOPER_TOOLS},
-            input_schema=GithubForkRepositoryBlock.Input,
-            output_schema=GithubForkRepositoryBlock.Output,
-            test_input={
-                "repo_url": "https://github.com/owner/repo",
-                "organization": "",
-                "credentials": TEST_CREDENTIALS_INPUT,
-            },
-            test_credentials=TEST_CREDENTIALS,
-            test_output=[
-                ("url", "https://github.com/myuser/repo"),
-                ("clone_url", "https://github.com/myuser/repo.git"),
-                ("full_name", "myuser/repo"),
-            ],
-            test_mock={
-                "fork_repo": lambda *args, **kwargs: (
-                    "https://github.com/myuser/repo",
-                    "https://github.com/myuser/repo.git",
-                    "myuser/repo",
-                )
-            },
-        )
-
-    @staticmethod
-    async def fork_repo(
-        credentials: GithubCredentials,
-        repo_url: str,
-        organization: str,
-    ) -> tuple[str, str, str]:
-        api = get_api(credentials)
-        forks_url = repo_url + "/forks"
-        data: dict[str, str] = {}
-        if organization:
-            data["organization"] = organization
-        response = await api.post(forks_url, json=data)
-        result = response.json()
-        return result["html_url"], result["clone_url"], result["full_name"]
-
-    async def run(
-        self,
-        input_data: Input,
-        *,
-        credentials: GithubCredentials,
-        **kwargs,
-    ) -> BlockOutput:
-        try:
-            url, clone_url, full_name = await self.fork_repo(
-                credentials,
-                input_data.repo_url,
-                input_data.organization,
-            )
-            yield "url", url
-            yield "clone_url", clone_url
-            yield "full_name", full_name
-        except Exception as e:
-            yield "error", str(e)
-
-
-class GithubStarRepositoryBlock(Block):
-    class Input(BlockSchemaInput):
-        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
-        repo_url: str = SchemaField(
-            description="URL of the GitHub repository to star",
-            placeholder="https://github.com/owner/repo",
-        )
-
-    class Output(BlockSchemaOutput):
-        status: str = SchemaField(description="Status of the star operation")
-        error: str = SchemaField(description="Error message if starring failed")
-
-    def __init__(self):
-        super().__init__(
-            id="bd700764-53e3-44dd-a969-d1854088458f",
-            description="This block stars a GitHub repository.",
-            categories={BlockCategory.DEVELOPER_TOOLS},
-            input_schema=GithubStarRepositoryBlock.Input,
-            output_schema=GithubStarRepositoryBlock.Output,
-            test_input={
-                "repo_url": "https://github.com/owner/repo",
-                "credentials": TEST_CREDENTIALS_INPUT,
-            },
-            test_credentials=TEST_CREDENTIALS,
-            test_output=[("status", "Repository starred successfully")],
-            test_mock={
-                "star_repo": lambda *args, **kwargs: "Repository starred successfully"
-            },
-        )
-
-    @staticmethod
-    async def star_repo(credentials: GithubCredentials, repo_url: str) -> str:
-        api = get_api(credentials, convert_urls=False)
-        repo_path = github_repo_path(repo_url)
-        owner, repo = repo_path.split("/")
-        await api.put(
-            f"https://api.github.com/user/starred/{owner}/{repo}",
-            headers={"Content-Length": "0"},
-        )
-        return "Repository starred successfully"
-
-    async def run(
-        self,
-        input_data: Input,
-        *,
-        credentials: GithubCredentials,
-        **kwargs,
-    ) -> BlockOutput:
-        try:
-            status = await self.star_repo(credentials, input_data.repo_url)
-            yield "status", status
-        except Exception as e:
-            yield "error", str(e)
@@ -1,452 +0,0 @@
-from urllib.parse import quote
-
-from typing_extensions import TypedDict
-
-from backend.blocks._base import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
-from backend.data.model import SchemaField
-
-from ._api import get_api
-from ._auth import (
-    TEST_CREDENTIALS,
-    TEST_CREDENTIALS_INPUT,
-    GithubCredentials,
-    GithubCredentialsField,
-    GithubCredentialsInput,
-)
-from ._utils import github_repo_path
-
-
-class GithubListBranchesBlock(Block):
-    class Input(BlockSchemaInput):
-        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
-        repo_url: str = SchemaField(
-            description="URL of the GitHub repository",
-            placeholder="https://github.com/owner/repo",
-        )
-        per_page: int = SchemaField(
-            description="Number of branches to return per page (max 100)",
-            default=30,
-            ge=1,
-            le=100,
-        )
-        page: int = SchemaField(
-            description="Page number for pagination",
-            default=1,
-            ge=1,
-        )
-
-    class Output(BlockSchemaOutput):
-        class BranchItem(TypedDict):
-            name: str
-            url: str
-
-        branch: BranchItem = SchemaField(
-            title="Branch",
-            description="Branches with their name and file tree browser URL",
-        )
-        branches: list[BranchItem] = SchemaField(
-            description="List of branches with their name and file tree browser URL"
-        )
-        error: str = SchemaField(description="Error message if listing branches failed")
-
-    def __init__(self):
-        super().__init__(
-            id="74243e49-2bec-4916-8bf4-db43d44aead5",
-            description="This block lists all branches for a specified GitHub repository.",
-            categories={BlockCategory.DEVELOPER_TOOLS},
-            input_schema=GithubListBranchesBlock.Input,
-            output_schema=GithubListBranchesBlock.Output,
-            test_input={
-                "repo_url": "https://github.com/owner/repo",
-                "per_page": 30,
-                "page": 1,
-                "credentials": TEST_CREDENTIALS_INPUT,
-            },
-            test_credentials=TEST_CREDENTIALS,
-            test_output=[
-                (
-                    "branches",
-                    [
-                        {
-                            "name": "main",
-                            "url": "https://github.com/owner/repo/tree/main",
-                        }
-                    ],
-                ),
-                (
-                    "branch",
-                    {
-                        "name": "main",
-                        "url": "https://github.com/owner/repo/tree/main",
-                    },
-                ),
-            ],
-            test_mock={
-                "list_branches": lambda *args, **kwargs: [
-                    {
-                        "name": "main",
-                        "url": "https://github.com/owner/repo/tree/main",
-                    }
-                ]
-            },
-        )
-
-    @staticmethod
-    async def list_branches(
-        credentials: GithubCredentials, repo_url: str, per_page: int, page: int
-    ) -> list[Output.BranchItem]:
-        api = get_api(credentials)
-        branches_url = repo_url + "/branches"
-        response = await api.get(
-            branches_url, params={"per_page": str(per_page), "page": str(page)}
-        )
-        data = response.json()
-        repo_path = github_repo_path(repo_url)
-        branches: list[GithubListBranchesBlock.Output.BranchItem] = [
-            {
-                "name": branch["name"],
-                "url": f"https://github.com/{repo_path}/tree/{branch['name']}",
-            }
-            for branch in data
-        ]
-        return branches
-
-    async def run(
-        self,
-        input_data: Input,
-        *,
-        credentials: GithubCredentials,
-        **kwargs,
-    ) -> BlockOutput:
-        try:
-            branches = await self.list_branches(
-                credentials,
-                input_data.repo_url,
-                input_data.per_page,
-                input_data.page,
-            )
-            yield "branches", branches
-            for branch in branches:
-                yield "branch", branch
-        except Exception as e:
-            yield "error", str(e)
-
-
-class GithubMakeBranchBlock(Block):
-    class Input(BlockSchemaInput):
-        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
-        repo_url: str = SchemaField(
-            description="URL of the GitHub repository",
-            placeholder="https://github.com/owner/repo",
-        )
-        new_branch: str = SchemaField(
-            description="Name of the new branch",
-            placeholder="new_branch_name",
-        )
-        source_branch: str = SchemaField(
-            description="Name of the source branch",
-            placeholder="source_branch_name",
-        )
-
-    class Output(BlockSchemaOutput):
-        status: str = SchemaField(description="Status of the branch creation operation")
-        error: str = SchemaField(
-            description="Error message if the branch creation failed"
-        )
-
-    def __init__(self):
-        super().__init__(
-            id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
-            description="This block creates a new branch from a specified source branch.",
-            categories={BlockCategory.DEVELOPER_TOOLS},
-            input_schema=GithubMakeBranchBlock.Input,
-            output_schema=GithubMakeBranchBlock.Output,
-            test_input={
-                "repo_url": "https://github.com/owner/repo",
-                "new_branch": "new_branch_name",
-                "source_branch": "source_branch_name",
-                "credentials": TEST_CREDENTIALS_INPUT,
-            },
-            test_credentials=TEST_CREDENTIALS,
-            test_output=[("status", "Branch created successfully")],
-            test_mock={
-                "create_branch": lambda *args, **kwargs: "Branch created successfully"
-            },
-        )
-
-    @staticmethod
-    async def create_branch(
-        credentials: GithubCredentials,
-        repo_url: str,
-        new_branch: str,
-        source_branch: str,
-    ) -> str:
-        api = get_api(credentials)
-        ref_url = repo_url + f"/git/refs/heads/{quote(source_branch, safe='')}"
-        response = await api.get(ref_url)
-        data = response.json()
-        sha = data["object"]["sha"]
-
-        # Create the new branch
-        new_ref_url = repo_url + "/git/refs"
-        data = {
-            "ref": f"refs/heads/{new_branch}",
-            "sha": sha,
-        }
-        response = await api.post(new_ref_url, json=data)
-        return "Branch created successfully"
-
-    async def run(
-        self,
-        input_data: Input,
-        *,
-        credentials: GithubCredentials,
-        **kwargs,
-    ) -> BlockOutput:
-        try:
-            status = await self.create_branch(
-                credentials,
-                input_data.repo_url,
-                input_data.new_branch,
-                input_data.source_branch,
-            )
-            yield "status", status
-        except Exception as e:
-            yield "error", str(e)
-
-
-class GithubDeleteBranchBlock(Block):
-    class Input(BlockSchemaInput):
-        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
-        repo_url: str = SchemaField(
-            description="URL of the GitHub repository",
-            placeholder="https://github.com/owner/repo",
-        )
-        branch: str = SchemaField(
-            description="Name of the branch to delete",
-            placeholder="branch_name",
-        )
-
-    class Output(BlockSchemaOutput):
-        status: str = SchemaField(description="Status of the branch deletion operation")
-        error: str = SchemaField(
-            description="Error message if the branch deletion failed"
-        )
-
-    def __init__(self):
-        super().__init__(
-            id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
-            description="This block deletes a specified branch.",
-            categories={BlockCategory.DEVELOPER_TOOLS},
-            input_schema=GithubDeleteBranchBlock.Input,
-            output_schema=GithubDeleteBranchBlock.Output,
-            test_input={
-                "repo_url": "https://github.com/owner/repo",
-                "branch": "branch_name",
-                "credentials": TEST_CREDENTIALS_INPUT,
-            },
-            test_credentials=TEST_CREDENTIALS,
-            test_output=[("status", "Branch deleted successfully")],
-            test_mock={
-                "delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
-            },
-            is_sensitive_action=True,
-        )
-
-    @staticmethod
-    async def delete_branch(
-        credentials: GithubCredentials, repo_url: str, branch: str
-    ) -> str:
-        api = get_api(credentials)
-        ref_url = repo_url + f"/git/refs/heads/{quote(branch, safe='')}"
-        await api.delete(ref_url)
-        return "Branch deleted successfully"
-
-    async def run(
-        self,
-        input_data: Input,
-        *,
-        credentials: GithubCredentials,
-        **kwargs,
-    ) -> BlockOutput:
-        try:
-            status = await self.delete_branch(
-                credentials,
-                input_data.repo_url,
-                input_data.branch,
-            )
-            yield "status", status
-        except Exception as e:
-            yield "error", str(e)
-
-
-class GithubCompareBranchesBlock(Block):
-    class Input(BlockSchemaInput):
-        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
-        repo_url: str = SchemaField(
-            description="URL of the GitHub repository",
-            placeholder="https://github.com/owner/repo",
-        )
-        base: str = SchemaField(
-            description="Base branch or commit SHA",
-            placeholder="main",
-        )
-        head: str = SchemaField(
-            description="Head branch or commit SHA to compare against base",
-            placeholder="feature-branch",
-        )
-
-    class Output(BlockSchemaOutput):
-        class FileChange(TypedDict):
-            filename: str
-            status: str
-            additions: int
-            deletions: int
-            patch: str
-
-        status: str = SchemaField(
-            description="Comparison status: ahead, behind, diverged, or identical"
-        )
-        ahead_by: int = SchemaField(
-            description="Number of commits head is ahead of base"
-        )
-        behind_by: int = SchemaField(
-            description="Number of commits head is behind base"
-        )
-        total_commits: int = SchemaField(
-            description="Total number of commits in the comparison"
-        )
-        diff: str = SchemaField(description="Unified diff of all file changes")
-        file: FileChange = SchemaField(
-            title="Changed File", description="A changed file with its diff"
-        )
-        files: list[FileChange] = SchemaField(
-            description="List of changed files with their diffs"
-        )
-        error: str = SchemaField(description="Error message if comparison failed")
-
-    def __init__(self):
-        super().__init__(
-            id="2e4faa8c-6086-4546-ba77-172d1d560186",
-            description="This block compares two branches or commits in a GitHub repository.",
-            categories={BlockCategory.DEVELOPER_TOOLS},
-            input_schema=GithubCompareBranchesBlock.Input,
-            output_schema=GithubCompareBranchesBlock.Output,
-            test_input={
-                "repo_url": "https://github.com/owner/repo",
-                "base": "main",
-                "head": "feature",
-                "credentials": TEST_CREDENTIALS_INPUT,
-            },
-            test_credentials=TEST_CREDENTIALS,
-            test_output=[
-                ("status", "ahead"),
-                ("ahead_by", 2),
-                ("behind_by", 0),
-                ("total_commits", 2),
-                ("diff", "+++ b/file.py\n+new line"),
-                (
-                    "files",
-                    [
-                        {
-                            "filename": "file.py",
-                            "status": "modified",
-                            "additions": 1,
-                            "deletions": 0,
-                            "patch": "+new line",
-                        }
-                    ],
-                ),
-                (
-                    "file",
-                    {
-                        "filename": "file.py",
-                        "status": "modified",
-                        "additions": 1,
-                        "deletions": 0,
-                        "patch": "+new line",
-                    },
-                ),
-            ],
-            test_mock={
-                "compare_branches": lambda *args, **kwargs: {
-                    "status": "ahead",
-                    "ahead_by": 2,
-                    "behind_by": 0,
-                    "total_commits": 2,
-                    "files": [
-                        {
-                            "filename": "file.py",
-                            "status": "modified",
-                            "additions": 1,
-                            "deletions": 0,
-                            "patch": "+new line",
-                        }
-                    ],
-                }
-            },
-        )
-
-    @staticmethod
-    async def compare_branches(
-        credentials: GithubCredentials,
-        repo_url: str,
-        base: str,
-        head: str,
-    ) -> dict:
-        api = get_api(credentials)
-        safe_base = quote(base, safe="")
-        safe_head = quote(head, safe="")
-        compare_url = repo_url + f"/compare/{safe_base}...{safe_head}"
-        response = await api.get(compare_url)
-        return response.json()
-
-    async def run(
-        self,
-        input_data: Input,
-        *,
-        credentials: GithubCredentials,
-        **kwargs,
-    ) -> BlockOutput:
-        try:
-            data = await self.compare_branches(
-                credentials,
-                input_data.repo_url,
-                input_data.base,
-                input_data.head,
-            )
-            yield "status", data["status"]
-            yield "ahead_by", data["ahead_by"]
-            yield "behind_by", data["behind_by"]
-            yield "total_commits", data["total_commits"]
-
-            files: list[GithubCompareBranchesBlock.Output.FileChange] = [
-                GithubCompareBranchesBlock.Output.FileChange(
-                    filename=f["filename"],
-                    status=f["status"],
-                    additions=f["additions"],
-                    deletions=f["deletions"],
-                    patch=f.get("patch", ""),
-                )
-                for f in data.get("files", [])
-            ]
-
-            # Build unified diff
-            diff_parts = []
-            for f in data.get("files", []):
-                patch = f.get("patch", "")
-                if patch:
-                    diff_parts.append(f"+++ b/{f['filename']}\n{patch}")
-            yield "diff", "\n".join(diff_parts)
-
-            yield "files", files
-            for file in files:
-                yield "file", file
-        except Exception as e:
-            yield "error", str(e)
@@ -1,720 +0,0 @@
-import base64
-from urllib.parse import quote
-
-from typing_extensions import TypedDict
-
-from backend.blocks._base import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
-from backend.data.model import SchemaField
-
-from ._api import get_api
-from ._auth import (
-    TEST_CREDENTIALS,
-    TEST_CREDENTIALS_INPUT,
-    GithubCredentials,
-    GithubCredentialsField,
-    GithubCredentialsInput,
-)
-
-
-class GithubReadFileBlock(Block):
-    class Input(BlockSchemaInput):
-        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
-        repo_url: str = SchemaField(
-            description="URL of the GitHub repository",
-            placeholder="https://github.com/owner/repo",
-        )
-        file_path: str = SchemaField(
-            description="Path to the file in the repository",
-            placeholder="path/to/file",
-        )
-        branch: str = SchemaField(
-            description="Branch to read from",
-            placeholder="branch_name",
-            default="main",
-        )
-
-    class Output(BlockSchemaOutput):
-        text_content: str = SchemaField(
-            description="Content of the file (decoded as UTF-8 text)"
-        )
-        raw_content: str = SchemaField(
-            description="Raw base64-encoded content of the file"
-        )
-        size: int = SchemaField(description="The size of the file (in bytes)")
-        error: str = SchemaField(description="Error message if reading the file failed")
-
-    def __init__(self):
-        super().__init__(
-            id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
-            description="This block reads the content of a specified file from a GitHub repository.",
-            categories={BlockCategory.DEVELOPER_TOOLS},
-            input_schema=GithubReadFileBlock.Input,
-            output_schema=GithubReadFileBlock.Output,
-            test_input={
-                "repo_url": "https://github.com/owner/repo",
-                "file_path": "path/to/file",
-                "branch": "main",
-                "credentials": TEST_CREDENTIALS_INPUT,
-            },
-            test_credentials=TEST_CREDENTIALS,
-            test_output=[
-                ("raw_content", "RmlsZSBjb250ZW50"),
-                ("text_content", "File content"),
-                ("size", 13),
-            ],
-            test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
-        )
-
-    @staticmethod
-    async def read_file(
-        credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
-    ) -> tuple[str, int]:
-        api = get_api(credentials)
-        content_url = (
-            repo_url
-            + f"/contents/{quote(file_path, safe='')}?ref={quote(branch, safe='')}"
-        )
-        response = await api.get(content_url)
-        data = response.json()
-
-        if isinstance(data, list):
-            # Multiple entries of different types exist at this path
-            if not (file := next((f for f in data if f["type"] == "file"), None)):
-                raise TypeError("Not a file")
-            data = file
-
-        if data["type"] != "file":
-            raise TypeError("Not a file")
-
-        return data["content"], data["size"]
-
-    async def run(
-        self,
-        input_data: Input,
-        *,
-        credentials: GithubCredentials,
-        **kwargs,
-    ) -> BlockOutput:
-        try:
-            content, size = await self.read_file(
-                credentials,
-                input_data.repo_url,
-                input_data.file_path,
-                input_data.branch,
-            )
-            yield "raw_content", content
-            yield "text_content", base64.b64decode(content).decode("utf-8")
-            yield "size", size
-        except Exception as e:
-            yield "error", str(e)
-
-
-class GithubReadFolderBlock(Block):
-    class Input(BlockSchemaInput):
-        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
-        repo_url: str = SchemaField(
-            description="URL of the GitHub repository",
-            placeholder="https://github.com/owner/repo",
-        )
-        folder_path: str = SchemaField(
-            description="Path to the folder in the repository",
-            placeholder="path/to/folder",
-        )
-        branch: str = SchemaField(
-            description="Branch name to read from (defaults to main)",
-            placeholder="branch_name",
-            default="main",
-        )
-
-    class Output(BlockSchemaOutput):
-        class DirEntry(TypedDict):
-            name: str
-            path: str
-
-        class FileEntry(TypedDict):
-            name: str
-            path: str
-            size: int
-
-        file: FileEntry = SchemaField(description="Files in the folder")
-        dir: DirEntry = SchemaField(description="Directories in the folder")
-        error: str = SchemaField(
-            description="Error message if reading the folder failed"
-        )
-
-    def __init__(self):
-        super().__init__(
-            id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
-            description="This block reads the content of a specified folder from a GitHub repository.",
-            categories={BlockCategory.DEVELOPER_TOOLS},
-            input_schema=GithubReadFolderBlock.Input,
-            output_schema=GithubReadFolderBlock.Output,
-            test_input={
-                "repo_url": "https://github.com/owner/repo",
-                "folder_path": "path/to/folder",
-                "branch": "main",
-                "credentials": TEST_CREDENTIALS_INPUT,
-            },
-            test_credentials=TEST_CREDENTIALS,
-            test_output=[
-                (
-                    "file",
-                    {
-                        "name": "file1.txt",
-                        "path": "path/to/folder/file1.txt",
-                        "size": 1337,
-                    },
-                ),
-                ("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
-            ],
-            test_mock={
-                "read_folder": lambda *args, **kwargs: (
-                    [
-                        {
-                            "name": "file1.txt",
-                            "path": "path/to/folder/file1.txt",
-                            "size": 1337,
-                        }
-                    ],
-                    [{"name": "dir2", "path": "path/to/folder/dir2"}],
-                )
-            },
-        )
-
-    @staticmethod
-    async def read_folder(
-        credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
-    ) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
-        api = get_api(credentials)
-        contents_url = (
-            repo_url
-            + f"/contents/{quote(folder_path, safe='/')}?ref={quote(branch, safe='')}"
-        )
-        response = await api.get(contents_url)
-        data = response.json()
-
-        if not isinstance(data, list):
-            raise TypeError("Not a folder")
-
-        files: list[GithubReadFolderBlock.Output.FileEntry] = [
-            GithubReadFolderBlock.Output.FileEntry(
-                name=entry["name"],
-                path=entry["path"],
-                size=entry["size"],
-            )
-            for entry in data
-            if entry["type"] == "file"
-        ]
-
-        dirs: list[GithubReadFolderBlock.Output.DirEntry] = [
-            GithubReadFolderBlock.Output.DirEntry(
-                name=entry["name"],
-                path=entry["path"],
-            )
-            for entry in data
-            if entry["type"] == "dir"
-        ]
-
-        return files, dirs
-
-    async def run(
-        self,
-        input_data: Input,
-        *,
-        credentials: GithubCredentials,
-        **kwargs,
-    ) -> BlockOutput:
-        try:
-            files, dirs = await self.read_folder(
-                credentials,
-                input_data.repo_url,
-                input_data.folder_path.lstrip("/"),
-                input_data.branch,
-            )
-            for file in files:
-                yield "file", file
-            for dir in dirs:
-                yield "dir", dir
-        except Exception as e:
-            yield "error", str(e)
-
-
-class GithubCreateFileBlock(Block):
-    class Input(BlockSchemaInput):
-        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
-        repo_url: str = SchemaField(
-            description="URL of the GitHub repository",
-            placeholder="https://github.com/owner/repo",
-        )
-        file_path: str = SchemaField(
-            description="Path where the file should be created",
-            placeholder="path/to/file.txt",
-        )
-        content: str = SchemaField(
-            description="Content to write to the file",
-            placeholder="File content here",
-        )
-        branch: str = SchemaField(
-            description="Branch where the file should be created",
-            default="main",
-        )
-        commit_message: str = SchemaField(
-            description="Message for the commit",
-            default="Create new file",
-        )
-
-    class Output(BlockSchemaOutput):
-        url: str = SchemaField(description="URL of the created file")
-        sha: str = SchemaField(description="SHA of the commit")
-        error: str = SchemaField(
-            description="Error message if the file creation failed"
-        )
-
-    def __init__(self):
-        super().__init__(
-            id="8fd132ac-b917-428a-8159-d62893e8a3fe",
-            description="This block creates a new file in a GitHub repository.",
-            categories={BlockCategory.DEVELOPER_TOOLS},
-            input_schema=GithubCreateFileBlock.Input,
-            output_schema=GithubCreateFileBlock.Output,
-            test_input={
-                "repo_url": "https://github.com/owner/repo",
-                "file_path": "test/file.txt",
-                "content": "Test content",
-                "branch": "main",
-                "commit_message": "Create test file",
-                "credentials": TEST_CREDENTIALS_INPUT,
|
|
||||||
},
|
|
||||||
test_credentials=TEST_CREDENTIALS,
|
|
||||||
test_output=[
|
|
||||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
|
||||||
("sha", "abc123"),
|
|
||||||
],
|
|
||||||
test_mock={
|
|
||||||
"create_file": lambda *args, **kwargs: (
|
|
||||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
|
||||||
"abc123",
|
|
||||||
)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def create_file(
|
|
||||||
credentials: GithubCredentials,
|
|
||||||
repo_url: str,
|
|
||||||
file_path: str,
|
|
||||||
content: str,
|
|
||||||
branch: str,
|
|
||||||
commit_message: str,
|
|
||||||
) -> tuple[str, str]:
|
|
||||||
api = get_api(credentials)
|
|
||||||
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
|
|
||||||
content_base64 = base64.b64encode(content.encode()).decode()
|
|
||||||
data = {
|
|
||||||
"message": commit_message,
|
|
||||||
"content": content_base64,
|
|
||||||
"branch": branch,
|
|
||||||
}
|
|
||||||
response = await api.put(contents_url, json=data)
|
|
||||||
data = response.json()
|
|
||||||
return data["content"]["html_url"], data["commit"]["sha"]
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
credentials: GithubCredentials,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
try:
|
|
||||||
url, sha = await self.create_file(
|
|
||||||
credentials,
|
|
||||||
input_data.repo_url,
|
|
||||||
input_data.file_path,
|
|
||||||
input_data.content,
|
|
||||||
input_data.branch,
|
|
||||||
input_data.commit_message,
|
|
||||||
)
|
|
||||||
yield "url", url
|
|
||||||
yield "sha", sha
|
|
||||||
except Exception as e:
|
|
||||||
yield "error", str(e)
|
|
||||||
|
|
||||||
|
|
||||||
class GithubUpdateFileBlock(Block):
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
|
||||||
repo_url: str = SchemaField(
|
|
||||||
description="URL of the GitHub repository",
|
|
||||||
placeholder="https://github.com/owner/repo",
|
|
||||||
)
|
|
||||||
file_path: str = SchemaField(
|
|
||||||
description="Path to the file to update",
|
|
||||||
placeholder="path/to/file.txt",
|
|
||||||
)
|
|
||||||
content: str = SchemaField(
|
|
||||||
description="New content for the file",
|
|
||||||
placeholder="Updated content here",
|
|
||||||
)
|
|
||||||
branch: str = SchemaField(
|
|
||||||
description="Branch containing the file",
|
|
||||||
default="main",
|
|
||||||
)
|
|
||||||
commit_message: str = SchemaField(
|
|
||||||
description="Message for the commit",
|
|
||||||
default="Update file",
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
url: str = SchemaField(description="URL of the updated file")
|
|
||||||
sha: str = SchemaField(description="SHA of the commit")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
|
|
||||||
description="This block updates an existing file in a GitHub repository.",
|
|
||||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
|
||||||
input_schema=GithubUpdateFileBlock.Input,
|
|
||||||
output_schema=GithubUpdateFileBlock.Output,
|
|
||||||
test_input={
|
|
||||||
"repo_url": "https://github.com/owner/repo",
|
|
||||||
"file_path": "test/file.txt",
|
|
||||||
"content": "Updated content",
|
|
||||||
"branch": "main",
|
|
||||||
"commit_message": "Update test file",
|
|
||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
|
||||||
},
|
|
||||||
test_credentials=TEST_CREDENTIALS,
|
|
||||||
test_output=[
|
|
||||||
("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
|
|
||||||
("sha", "def456"),
|
|
||||||
],
|
|
||||||
test_mock={
|
|
||||||
"update_file": lambda *args, **kwargs: (
|
|
||||||
"https://github.com/owner/repo/blob/main/test/file.txt",
|
|
||||||
"def456",
|
|
||||||
)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def update_file(
|
|
||||||
credentials: GithubCredentials,
|
|
||||||
repo_url: str,
|
|
||||||
file_path: str,
|
|
||||||
content: str,
|
|
||||||
branch: str,
|
|
||||||
commit_message: str,
|
|
||||||
) -> tuple[str, str]:
|
|
||||||
api = get_api(credentials)
|
|
||||||
contents_url = repo_url + f"/contents/{quote(file_path, safe='/')}"
|
|
||||||
params = {"ref": branch}
|
|
||||||
response = await api.get(contents_url, params=params)
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
# Convert new content to base64
|
|
||||||
content_base64 = base64.b64encode(content.encode()).decode()
|
|
||||||
data = {
|
|
||||||
"message": commit_message,
|
|
||||||
"content": content_base64,
|
|
||||||
"sha": data["sha"],
|
|
||||||
"branch": branch,
|
|
||||||
}
|
|
||||||
response = await api.put(contents_url, json=data)
|
|
||||||
data = response.json()
|
|
||||||
return data["content"]["html_url"], data["commit"]["sha"]
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
credentials: GithubCredentials,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
try:
|
|
||||||
url, sha = await self.update_file(
|
|
||||||
credentials,
|
|
||||||
input_data.repo_url,
|
|
||||||
input_data.file_path,
|
|
||||||
input_data.content,
|
|
||||||
input_data.branch,
|
|
||||||
input_data.commit_message,
|
|
||||||
)
|
|
||||||
yield "url", url
|
|
||||||
yield "sha", sha
|
|
||||||
except Exception as e:
|
|
||||||
yield "error", str(e)
|
|
||||||
|
|
||||||
|
|
||||||
class GithubSearchCodeBlock(Block):
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
|
||||||
query: str = SchemaField(
|
|
||||||
description="Search query (GitHub code search syntax)",
|
|
||||||
placeholder="className language:python",
|
|
||||||
)
|
|
||||||
repo: str = SchemaField(
|
|
||||||
description="Restrict search to a repository (owner/repo format, optional)",
|
|
||||||
default="",
|
|
||||||
placeholder="owner/repo",
|
|
||||||
)
|
|
||||||
per_page: int = SchemaField(
|
|
||||||
description="Number of results to return (max 100)",
|
|
||||||
default=30,
|
|
||||||
ge=1,
|
|
||||||
le=100,
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
class SearchResult(TypedDict):
|
|
||||||
name: str
|
|
||||||
path: str
|
|
||||||
repository: str
|
|
||||||
url: str
|
|
||||||
score: float
|
|
||||||
|
|
||||||
result: SearchResult = SchemaField(
|
|
||||||
title="Result", description="A code search result"
|
|
||||||
)
|
|
||||||
results: list[SearchResult] = SchemaField(
|
|
||||||
description="List of code search results"
|
|
||||||
)
|
|
||||||
total_count: int = SchemaField(description="Total number of matching results")
|
|
||||||
error: str = SchemaField(description="Error message if search failed")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="47f94891-a2b1-4f1c-b5f2-573c043f721e",
|
|
||||||
description="This block searches for code in GitHub repositories.",
|
|
||||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
|
||||||
input_schema=GithubSearchCodeBlock.Input,
|
|
||||||
output_schema=GithubSearchCodeBlock.Output,
|
|
||||||
test_input={
|
|
||||||
"query": "addClass",
|
|
||||||
"repo": "owner/repo",
|
|
||||||
"per_page": 30,
|
|
||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
|
||||||
},
|
|
||||||
test_credentials=TEST_CREDENTIALS,
|
|
||||||
test_output=[
|
|
||||||
("total_count", 1),
|
|
||||||
(
|
|
||||||
"results",
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": "file.py",
|
|
||||||
"path": "src/file.py",
|
|
||||||
"repository": "owner/repo",
|
|
||||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
|
||||||
"score": 1.0,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"result",
|
|
||||||
{
|
|
||||||
"name": "file.py",
|
|
||||||
"path": "src/file.py",
|
|
||||||
"repository": "owner/repo",
|
|
||||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
|
||||||
"score": 1.0,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
test_mock={
|
|
||||||
"search_code": lambda *args, **kwargs: (
|
|
||||||
1,
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": "file.py",
|
|
||||||
"path": "src/file.py",
|
|
||||||
"repository": "owner/repo",
|
|
||||||
"url": "https://github.com/owner/repo/blob/main/src/file.py",
|
|
||||||
"score": 1.0,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def search_code(
|
|
||||||
credentials: GithubCredentials,
|
|
||||||
query: str,
|
|
||||||
repo: str,
|
|
||||||
per_page: int,
|
|
||||||
) -> tuple[int, list[Output.SearchResult]]:
|
|
||||||
api = get_api(credentials, convert_urls=False)
|
|
||||||
full_query = f"{query} repo:{repo}" if repo else query
|
|
||||||
params = {"q": full_query, "per_page": str(per_page)}
|
|
||||||
response = await api.get("https://api.github.com/search/code", params=params)
|
|
||||||
data = response.json()
|
|
||||||
results: list[GithubSearchCodeBlock.Output.SearchResult] = [
|
|
||||||
GithubSearchCodeBlock.Output.SearchResult(
|
|
||||||
name=item["name"],
|
|
||||||
path=item["path"],
|
|
||||||
repository=item["repository"]["full_name"],
|
|
||||||
url=item["html_url"],
|
|
||||||
score=item["score"],
|
|
||||||
)
|
|
||||||
for item in data["items"]
|
|
||||||
]
|
|
||||||
return data["total_count"], results
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
credentials: GithubCredentials,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
try:
|
|
||||||
total_count, results = await self.search_code(
|
|
||||||
credentials,
|
|
||||||
input_data.query,
|
|
||||||
input_data.repo,
|
|
||||||
input_data.per_page,
|
|
||||||
)
|
|
||||||
yield "total_count", total_count
|
|
||||||
yield "results", results
|
|
||||||
for result in results:
|
|
||||||
yield "result", result
|
|
||||||
except Exception as e:
|
|
||||||
yield "error", str(e)
|
|
||||||
|
|
||||||
|
|
||||||
class GithubGetRepositoryTreeBlock(Block):
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
|
||||||
repo_url: str = SchemaField(
|
|
||||||
description="URL of the GitHub repository",
|
|
||||||
placeholder="https://github.com/owner/repo",
|
|
||||||
)
|
|
||||||
branch: str = SchemaField(
|
|
||||||
description="Branch name to get the tree from",
|
|
||||||
default="main",
|
|
||||||
)
|
|
||||||
recursive: bool = SchemaField(
|
|
||||||
description="Whether to recursively list the entire tree",
|
|
||||||
default=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
class TreeEntry(TypedDict):
|
|
||||||
path: str
|
|
||||||
type: str
|
|
||||||
size: int
|
|
||||||
sha: str
|
|
||||||
|
|
||||||
entry: TreeEntry = SchemaField(
|
|
||||||
title="Tree Entry", description="A file or directory in the tree"
|
|
||||||
)
|
|
||||||
entries: list[TreeEntry] = SchemaField(
|
|
||||||
description="List of all files and directories in the tree"
|
|
||||||
)
|
|
||||||
truncated: bool = SchemaField(
|
|
||||||
description="Whether the tree was truncated due to size"
|
|
||||||
)
|
|
||||||
error: str = SchemaField(description="Error message if getting tree failed")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="89c5c0ec-172e-4001-a32c-bdfe4d0c9e81",
|
|
||||||
description="This block lists the entire file tree of a GitHub repository recursively.",
|
|
||||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
|
||||||
input_schema=GithubGetRepositoryTreeBlock.Input,
|
|
||||||
output_schema=GithubGetRepositoryTreeBlock.Output,
|
|
||||||
test_input={
|
|
||||||
"repo_url": "https://github.com/owner/repo",
|
|
||||||
"branch": "main",
|
|
||||||
"recursive": True,
|
|
||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
|
||||||
},
|
|
||||||
test_credentials=TEST_CREDENTIALS,
|
|
||||||
test_output=[
|
|
||||||
("truncated", False),
|
|
||||||
(
|
|
||||||
"entries",
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"path": "src/main.py",
|
|
||||||
"type": "blob",
|
|
||||||
"size": 1234,
|
|
||||||
"sha": "abc123",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"entry",
|
|
||||||
{
|
|
||||||
"path": "src/main.py",
|
|
||||||
"type": "blob",
|
|
||||||
"size": 1234,
|
|
||||||
"sha": "abc123",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
test_mock={
|
|
||||||
"get_tree": lambda *args, **kwargs: (
|
|
||||||
False,
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"path": "src/main.py",
|
|
||||||
"type": "blob",
|
|
||||||
"size": 1234,
|
|
||||||
"sha": "abc123",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_tree(
|
|
||||||
credentials: GithubCredentials,
|
|
||||||
repo_url: str,
|
|
||||||
branch: str,
|
|
||||||
recursive: bool,
|
|
||||||
) -> tuple[bool, list[Output.TreeEntry]]:
|
|
||||||
api = get_api(credentials)
|
|
||||||
tree_url = repo_url + f"/git/trees/{quote(branch, safe='')}"
|
|
||||||
params = {"recursive": "1"} if recursive else {}
|
|
||||||
response = await api.get(tree_url, params=params)
|
|
||||||
data = response.json()
|
|
||||||
entries: list[GithubGetRepositoryTreeBlock.Output.TreeEntry] = [
|
|
||||||
GithubGetRepositoryTreeBlock.Output.TreeEntry(
|
|
||||||
path=item["path"],
|
|
||||||
type=item["type"],
|
|
||||||
size=item.get("size", 0),
|
|
||||||
sha=item["sha"],
|
|
||||||
)
|
|
||||||
for item in data["tree"]
|
|
||||||
]
|
|
||||||
return data.get("truncated", False), entries
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
credentials: GithubCredentials,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
try:
|
|
||||||
truncated, entries = await self.get_tree(
|
|
||||||
credentials,
|
|
||||||
input_data.repo_url,
|
|
||||||
input_data.branch,
|
|
||||||
input_data.recursive,
|
|
||||||
)
|
|
||||||
yield "truncated", truncated
|
|
||||||
yield "entries", entries
|
|
||||||
for entry in entries:
|
|
||||||
yield "entry", entry
|
|
||||||
except Exception as e:
|
|
||||||
yield "error", str(e)
|
|
||||||
@@ -1,125 +0,0 @@
import inspect

import pytest

from backend.blocks.github._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
from backend.blocks.github.commits import FileOperation, GithubMultiFileCommitBlock
from backend.blocks.github.pull_requests import (
    GithubMergePullRequestBlock,
    prepare_pr_api_url,
)
from backend.data.execution import ExecutionContext
from backend.util.exceptions import BlockExecutionError

# ── prepare_pr_api_url tests ──


class TestPreparePrApiUrl:
    def test_https_scheme_preserved(self):
        result = prepare_pr_api_url("https://github.com/owner/repo/pull/42", "merge")
        assert result == "https://github.com/owner/repo/pulls/42/merge"

    def test_http_scheme_preserved(self):
        result = prepare_pr_api_url("http://github.com/owner/repo/pull/1", "files")
        assert result == "http://github.com/owner/repo/pulls/1/files"

    def test_no_scheme_defaults_to_https(self):
        result = prepare_pr_api_url("github.com/owner/repo/pull/5", "merge")
        assert result == "https://github.com/owner/repo/pulls/5/merge"

    def test_reviewers_path(self):
        result = prepare_pr_api_url(
            "https://github.com/owner/repo/pull/99", "requested_reviewers"
        )
        assert result == "https://github.com/owner/repo/pulls/99/requested_reviewers"

    def test_invalid_url_returned_as_is(self):
        url = "https://example.com/not-a-pr"
        assert prepare_pr_api_url(url, "merge") == url

    def test_empty_string(self):
        assert prepare_pr_api_url("", "merge") == ""


# ── Error-path block tests ──
# When a block's run() yields ("error", msg), _execute() converts it to a
# BlockExecutionError. We call block.execute() directly (not execute_block_test,
# which returns early on empty test_output).


def _mock_block(block, mocks: dict):
    """Apply mocks to a block's static methods, wrapping sync mocks as async."""
    for name, mock_fn in mocks.items():
        original = getattr(block, name)
        if inspect.iscoroutinefunction(original):

            async def async_mock(*args, _fn=mock_fn, **kwargs):
                return _fn(*args, **kwargs)

            setattr(block, name, async_mock)
        else:
            setattr(block, name, mock_fn)


def _raise(exc: Exception):
    """Helper that returns a callable which raises the given exception."""

    def _raiser(*args, **kwargs):
        raise exc

    return _raiser


@pytest.mark.asyncio
async def test_merge_pr_error_path():
    block = GithubMergePullRequestBlock()
    _mock_block(block, {"merge_pr": _raise(RuntimeError("PR not mergeable"))})
    input_data = {
        "pr_url": "https://github.com/owner/repo/pull/1",
        "merge_method": "squash",
        "commit_title": "",
        "commit_message": "",
        "credentials": TEST_CREDENTIALS_INPUT,
    }
    with pytest.raises(BlockExecutionError, match="PR not mergeable"):
        async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
            pass


@pytest.mark.asyncio
async def test_multi_file_commit_error_path():
    block = GithubMultiFileCommitBlock()
    _mock_block(block, {"multi_file_commit": _raise(RuntimeError("ref update failed"))})
    input_data = {
        "repo_url": "https://github.com/owner/repo",
        "branch": "feature",
        "commit_message": "test",
        "files": [{"path": "a.py", "content": "x", "operation": "upsert"}],
        "credentials": TEST_CREDENTIALS_INPUT,
    }
    with pytest.raises(BlockExecutionError, match="ref update failed"):
        async for _ in block.execute(
            input_data,
            credentials=TEST_CREDENTIALS,
            execution_context=ExecutionContext(),
        ):
            pass


# ── FileOperation enum tests ──


class TestFileOperation:
    def test_upsert_value(self):
        assert FileOperation.UPSERT == "upsert"

    def test_delete_value(self):
        assert FileOperation.DELETE == "delete"

    def test_invalid_value_raises(self):
        with pytest.raises(ValueError):
            FileOperation("create")

    def test_invalid_value_raises_typo(self):
        with pytest.raises(ValueError):
            FileOperation("upser")
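The `TestPreparePrApiUrl` cases above pin down the URL-rewriting contract: a GitHub PR web URL (`/pull/N`) becomes an API-style path (`/pulls/N/<action>`), a missing scheme defaults to `https`, and anything that is not a PR URL passes through untouched. A hypothetical re-implementation that satisfies those assertions (the real helper lives in `backend.blocks.github.pull_requests`; this sketch only mirrors the tested behavior):

```python
from urllib.parse import urlparse


def prepare_pr_api_url_sketch(pr_url: str, path: str) -> str:
    """Rewrite a GitHub PR web URL into its API-style pulls path.

    Illustrative sketch matching the test expectations; not the actual
    implementation from the codebase.
    """
    # Default to https:// when no scheme is present.
    if "://" not in pr_url:
        candidate = "https://" + pr_url
        scheme = "https"
    else:
        candidate = pr_url
        scheme = urlparse(pr_url).scheme
    parsed = urlparse(candidate)
    parts = parsed.path.strip("/").split("/")
    # Expect owner/repo/pull/<number>; otherwise return the input unchanged.
    if (
        parsed.netloc != "github.com"
        or len(parts) != 4
        or parts[2] != "pull"
        or not parts[3].isdigit()
    ):
        return pr_url
    return f"{scheme}://github.com/{parts[0]}/{parts[1]}/pulls/{parts[3]}/{path}"
```

Note the output stays on the `github.com` host; per the tests, conversion to `api.github.com` is handled elsewhere (presumably by `get_api`).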
@@ -241,8 +241,8 @@ class GmailBase(Block, ABC):
             h.ignore_links = False
             h.ignore_images = True
             return h.handle(html_content)
-        except Exception:
-            # Keep extraction resilient if html2text is unavailable or fails.
+        except ImportError:
+            # Fallback: return raw HTML if html2text is not available
             return html_content

     # Handle content stored as attachment
@@ -67,7 +67,6 @@ class HITLReviewHelper:
         graph_version: int,
         block_name: str = "Block",
         editable: bool = False,
-        is_graph_execution: bool = True,
     ) -> Optional[ReviewResult]:
         """
         Handle a review request for a block that requires human review.
@@ -144,11 +143,10 @@ class HITLReviewHelper:
         logger.info(
             f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
         )
-        if is_graph_execution:
-            await HITLReviewHelper.update_node_execution_status(
-                exec_id=node_exec_id,
-                status=ExecutionStatus.REVIEW,
-            )
+        await HITLReviewHelper.update_node_execution_status(
+            exec_id=node_exec_id,
+            status=ExecutionStatus.REVIEW,
+        )
         return None  # Signal that execution should pause

         # Mark review as processed if not already done
@@ -170,7 +168,6 @@ class HITLReviewHelper:
         graph_version: int,
         block_name: str = "Block",
         editable: bool = False,
-        is_graph_execution: bool = True,
     ) -> Optional[ReviewDecision]:
         """
         Handle a review request and return the decision in a single call.
@@ -200,7 +197,6 @@ class HITLReviewHelper:
             graph_version=graph_version,
             block_name=block_name,
             editable=editable,
-            is_graph_execution=is_graph_execution,
         )

         if review_result is None:
@@ -17,7 +17,7 @@ from backend.blocks.jina._auth import (
 from backend.blocks.search import GetRequest
 from backend.data.model import SchemaField
 from backend.util.exceptions import BlockExecutionError
-from backend.util.request import HTTPClientError, HTTPServerError, validate_url_host
+from backend.util.request import HTTPClientError, HTTPServerError, validate_url


 class SearchTheWebBlock(Block, GetRequest):
@@ -112,7 +112,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
     ) -> BlockOutput:
         if input_data.raw_content:
             try:
-                parsed_url, _, _ = await validate_url_host(input_data.url)
+                parsed_url, _, _ = await validate_url(input_data.url, [])
                 url = parsed_url.geturl()
             except ValueError as e:
                 yield "error", f"Invalid URL: {e}"
@@ -31,14 +31,10 @@ from backend.data.model import (
|
|||||||
)
|
)
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util import json
|
from backend.util import json
|
||||||
from backend.util.clients import OPENROUTER_BASE_URL
|
|
||||||
from backend.util.logging import TruncatedLogger
|
from backend.util.logging import TruncatedLogger
|
||||||
from backend.util.prompt import compress_context, estimate_token_count
|
from backend.util.prompt import compress_context, estimate_token_count
|
||||||
from backend.util.request import validate_url_host
|
|
||||||
from backend.util.settings import Settings
|
|
||||||
from backend.util.text import TextFormatter
|
from backend.util.text import TextFormatter
|
||||||
|
|
||||||
settings = Settings()
|
|
||||||
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
||||||
fmt = TextFormatter(autoescape=False)
|
fmt = TextFormatter(autoescape=False)
|
||||||
|
|
||||||
@@ -120,7 +116,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
|||||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||||
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
||||||
CLAUDE_4_6_OPUS = "claude-opus-4-6"
|
CLAUDE_4_6_OPUS = "claude-opus-4-6"
|
||||||
CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
|
|
||||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||||
# AI/ML API models
|
# AI/ML API models
|
||||||
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
||||||
@@ -140,31 +135,19 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
|||||||
# OpenRouter models
|
# OpenRouter models
|
||||||
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
|
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
|
||||||
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
|
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
|
||||||
GEMINI_2_5_PRO_PREVIEW = "google/gemini-2.5-pro-preview-03-25"
|
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
|
||||||
GEMINI_2_5_PRO = "google/gemini-2.5-pro"
|
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
|
||||||
GEMINI_3_1_PRO_PREVIEW = "google/gemini-3.1-pro-preview"
|
|
||||||
GEMINI_3_FLASH_PREVIEW = "google/gemini-3-flash-preview"
|
|
||||||
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
|
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
|
||||||
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
|
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
|
||||||
GEMINI_3_1_FLASH_LITE_PREVIEW = "google/gemini-3.1-flash-lite-preview"
|
|
||||||
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
|
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
|
||||||
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
|
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
|
||||||
MISTRAL_NEMO = "mistralai/mistral-nemo"
|
MISTRAL_NEMO = "mistralai/mistral-nemo"
|
||||||
MISTRAL_LARGE_3 = "mistralai/mistral-large-2512"
|
|
||||||
MISTRAL_MEDIUM_3_1 = "mistralai/mistral-medium-3.1"
|
|
||||||
MISTRAL_SMALL_3_2 = "mistralai/mistral-small-3.2-24b-instruct"
|
|
||||||
CODESTRAL = "mistralai/codestral-2508"
|
|
||||||
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
|
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
|
||||||
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
|
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
|
||||||
COHERE_COMMAND_A_03_2025 = "cohere/command-a-03-2025"
|
|
||||||
COHERE_COMMAND_A_TRANSLATE_08_2025 = "cohere/command-a-translate-08-2025"
|
|
||||||
COHERE_COMMAND_A_REASONING_08_2025 = "cohere/command-a-reasoning-08-2025"
|
|
||||||
COHERE_COMMAND_A_VISION_07_2025 = "cohere/command-a-vision-07-2025"
|
|
||||||
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
|
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
|
||||||
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
|
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
|
||||||
PERPLEXITY_SONAR = "perplexity/sonar"
|
PERPLEXITY_SONAR = "perplexity/sonar"
|
||||||
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
|
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
|
||||||
PERPLEXITY_SONAR_REASONING_PRO = "perplexity/sonar-reasoning-pro"
|
|
||||||
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
||||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
|
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
|
||||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
|
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
|
||||||
@@ -172,11 +155,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
|||||||
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
|
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
|
||||||
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
|
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
|
||||||
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
|
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
|
||||||
MICROSOFT_PHI_4 = "microsoft/phi-4"
|
|
||||||
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
|
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
|
||||||
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
|
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
|
||||||
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
|
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
|
||||||
GROK_3 = "x-ai/grok-3"
|
|
||||||
GROK_4 = "x-ai/grok-4"
|
GROK_4 = "x-ai/grok-4"
|
||||||
GROK_4_FAST = "x-ai/grok-4-fast"
|
GROK_4_FAST = "x-ai/grok-4-fast"
|
||||||
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
|
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
|
||||||
@@ -293,9 +274,6 @@ MODEL_METADATA = {
     LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
         "anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
     ),  # claude-opus-4-6
-    LlmModel.CLAUDE_4_6_SONNET: ModelMetadata(
-        "anthropic", 200000, 64000, "Claude Sonnet 4.6", "Anthropic", "Anthropic", 3
-    ),  # claude-sonnet-4-6
     LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
         "anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
     ),  # claude-opus-4-5-20251101
@@ -354,41 +332,17 @@ MODEL_METADATA = {
         "ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
     ),
     # https://openrouter.ai/models
-    LlmModel.GEMINI_2_5_PRO_PREVIEW: ModelMetadata(
+    LlmModel.GEMINI_2_5_PRO: ModelMetadata(
         "open_router",
-        1048576,
-        65536,
+        1050000,
+        8192,
         "Gemini 2.5 Pro Preview 03.25",
         "OpenRouter",
         "Google",
         2,
     ),
-    LlmModel.GEMINI_2_5_PRO: ModelMetadata(
-        "open_router",
-        1048576,
-        65536,
-        "Gemini 2.5 Pro",
-        "OpenRouter",
-        "Google",
-        2,
-    ),
-    LlmModel.GEMINI_3_1_PRO_PREVIEW: ModelMetadata(
-        "open_router",
-        1048576,
-        65536,
-        "Gemini 3.1 Pro Preview",
-        "OpenRouter",
-        "Google",
-        2,
-    ),
-    LlmModel.GEMINI_3_FLASH_PREVIEW: ModelMetadata(
-        "open_router",
-        1048576,
-        65536,
-        "Gemini 3 Flash Preview",
-        "OpenRouter",
-        "Google",
-        1,
-    ),
+    LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
+        "open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
+    ),
     LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
         "open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
@@ -396,15 +350,6 @@ MODEL_METADATA = {
     LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
         "open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
     ),
-    LlmModel.GEMINI_3_1_FLASH_LITE_PREVIEW: ModelMetadata(
-        "open_router",
-        1048576,
-        65536,
-        "Gemini 3.1 Flash Lite Preview",
-        "OpenRouter",
-        "Google",
-        1,
-    ),
     LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
         "open_router",
         1048576,
@@ -426,78 +371,12 @@ MODEL_METADATA = {
     LlmModel.MISTRAL_NEMO: ModelMetadata(
         "open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
     ),
-    LlmModel.MISTRAL_LARGE_3: ModelMetadata(
-        "open_router",
-        262144,
-        None,
-        "Mistral Large 3 2512",
-        "OpenRouter",
-        "Mistral AI",
-        2,
-    ),
-    LlmModel.MISTRAL_MEDIUM_3_1: ModelMetadata(
-        "open_router",
-        131072,
-        None,
-        "Mistral Medium 3.1",
-        "OpenRouter",
-        "Mistral AI",
-        2,
-    ),
-    LlmModel.MISTRAL_SMALL_3_2: ModelMetadata(
-        "open_router",
-        131072,
-        131072,
-        "Mistral Small 3.2 24B",
-        "OpenRouter",
-        "Mistral AI",
-        1,
-    ),
-    LlmModel.CODESTRAL: ModelMetadata(
-        "open_router",
-        256000,
-        None,
-        "Codestral 2508",
-        "OpenRouter",
-        "Mistral AI",
-        1,
-    ),
     LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
         "open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
     ),
     LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
         "open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
     ),
-    LlmModel.COHERE_COMMAND_A_03_2025: ModelMetadata(
-        "open_router", 256000, 8192, "Command A 03.2025", "OpenRouter", "Cohere", 2
-    ),
-    LlmModel.COHERE_COMMAND_A_TRANSLATE_08_2025: ModelMetadata(
-        "open_router",
-        128000,
-        8192,
-        "Command A Translate 08.2025",
-        "OpenRouter",
-        "Cohere",
-        2,
-    ),
-    LlmModel.COHERE_COMMAND_A_REASONING_08_2025: ModelMetadata(
-        "open_router",
-        256000,
-        32768,
-        "Command A Reasoning 08.2025",
-        "OpenRouter",
-        "Cohere",
-        3,
-    ),
-    LlmModel.COHERE_COMMAND_A_VISION_07_2025: ModelMetadata(
-        "open_router",
-        128000,
-        8192,
-        "Command A Vision 07.2025",
-        "OpenRouter",
-        "Cohere",
-        2,
-    ),
     LlmModel.DEEPSEEK_CHAT: ModelMetadata(
         "open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
     ),
@@ -510,15 +389,6 @@ MODEL_METADATA = {
     LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
         "open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
     ),
-    LlmModel.PERPLEXITY_SONAR_REASONING_PRO: ModelMetadata(
-        "open_router",
-        128000,
-        8000,
-        "Sonar Reasoning Pro",
-        "OpenRouter",
-        "Perplexity",
-        2,
-    ),
     LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
         "open_router",
         128000,
@@ -564,9 +434,6 @@ MODEL_METADATA = {
     LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
         "open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
     ),
-    LlmModel.MICROSOFT_PHI_4: ModelMetadata(
-        "open_router", 16384, 16384, "Phi-4", "OpenRouter", "Microsoft", 1
-    ),
     LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
         "open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
     ),
@@ -576,15 +443,6 @@ MODEL_METADATA = {
     LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
         "open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
     ),
-    LlmModel.GROK_3: ModelMetadata(
-        "open_router",
-        131072,
-        131072,
-        "Grok 3",
-        "OpenRouter",
-        "xAI",
-        2,
-    ),
     LlmModel.GROK_4: ModelMetadata(
         "open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
     ),
@@ -942,11 +800,6 @@ async def llm_call(
         if tools:
             raise ValueError("Ollama does not support tools.")

-        # Validate user-provided Ollama host to prevent SSRF etc.
-        await validate_url_host(
-            ollama_host, trusted_hostnames=[settings.config.ollama_host]
-        )
-
         client = ollama.AsyncClient(host=ollama_host)
         sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
         usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
@@ -968,7 +821,7 @@ async def llm_call(
     elif provider == "open_router":
         tools_param = tools if tools else openai.NOT_GIVEN
         client = openai.AsyncOpenAI(
-            base_url=OPENROUTER_BASE_URL,
+            base_url="https://openrouter.ai/api/v1",
            api_key=credentials.api_key.get_secret_value(),
         )

@@ -6,6 +6,7 @@ and execute them. Works like AgentExecutorBlock — the user selects a tool from
 dropdown and the input/output schema adapts dynamically.
 """

+import json
 import logging
 from typing import Any, Literal

@@ -19,11 +20,6 @@ from backend.blocks._base import (
     BlockType,
 )
 from backend.blocks.mcp.client import MCPClient, MCPClientError
-from backend.blocks.mcp.helpers import (
-    auto_lookup_mcp_credential,
-    normalize_mcp_url,
-    parse_mcp_content,
-)
 from backend.data.block import BlockInput, BlockOutput
 from backend.data.model import (
     CredentialsField,
@@ -183,7 +179,31 @@ class MCPToolBlock(Block):
                 f"{error_text or 'Unknown error'}"
             )

-        return parse_mcp_content(result.content)
+        # Extract text content from the result
+        output_parts = []
+        for item in result.content:
+            if item.get("type") == "text":
+                text = item.get("text", "")
+                # Try to parse as JSON for structured output
+                try:
+                    output_parts.append(json.loads(text))
+                except (json.JSONDecodeError, ValueError):
+                    output_parts.append(text)
+            elif item.get("type") == "image":
+                output_parts.append(
+                    {
+                        "type": "image",
+                        "data": item.get("data"),
+                        "mimeType": item.get("mimeType"),
+                    }
+                )
+            elif item.get("type") == "resource":
+                output_parts.append(item.get("resource", {}))
+
+        # If single result, unwrap
+        if len(output_parts) == 1:
+            return output_parts[0]
+        return output_parts if output_parts else None

     @staticmethod
     async def _auto_lookup_credential(
@@ -191,10 +211,37 @@ class MCPToolBlock(Block):
     ) -> "OAuth2Credentials | None":
         """Auto-lookup stored MCP credential for a server URL.

-        Delegates to :func:`~backend.blocks.mcp.helpers.auto_lookup_mcp_credential`.
-        The caller should pass a normalized URL.
+        This is a fallback for nodes that don't have ``credentials`` explicitly
+        set (e.g. nodes created before the credential field was wired up).
         """
-        return await auto_lookup_mcp_credential(user_id, server_url)
+        from backend.integrations.creds_manager import IntegrationCredentialsManager
+        from backend.integrations.providers import ProviderName
+
+        try:
+            mgr = IntegrationCredentialsManager()
+            mcp_creds = await mgr.store.get_creds_by_provider(
+                user_id, ProviderName.MCP.value
+            )
+            best: OAuth2Credentials | None = None
+            for cred in mcp_creds:
+                if (
+                    isinstance(cred, OAuth2Credentials)
+                    and (cred.metadata or {}).get("mcp_server_url") == server_url
+                ):
+                    if best is None or (
+                        (cred.access_token_expires_at or 0)
+                        > (best.access_token_expires_at or 0)
+                    ):
+                        best = cred
+            if best:
+                best = await mgr.refresh_if_needed(user_id, best)
+                logger.info(
+                    "Auto-resolved MCP credential %s for %s", best.id, server_url
+                )
+                return best
+        except Exception:
+            logger.warning("Auto-lookup MCP credential failed", exc_info=True)
+        return None

     async def run(
         self,
@@ -231,7 +278,7 @@ class MCPToolBlock(Block):
         # the stored MCP credential for this server URL.
         if credentials is None:
             credentials = await self._auto_lookup_credential(
-                user_id, normalize_mcp_url(input_data.server_url)
+                user_id, input_data.server_url
             )

         auth_token = (
@@ -55,9 +55,7 @@ class MCPClient:
         server_url: str,
         auth_token: str | None = None,
     ):
-        from backend.blocks.mcp.helpers import normalize_mcp_url
-
-        self.server_url = normalize_mcp_url(server_url)
+        self.server_url = server_url.rstrip("/")
         self.auth_token = auth_token
         self._request_id = 0
         self._session_id: str | None = None
@@ -1,117 +0,0 @@
-"""Shared MCP helpers used by blocks, copilot tools, and API routes."""
-
-from __future__ import annotations
-
-import json
-import logging
-from typing import TYPE_CHECKING, Any
-from urllib.parse import urlparse
-
-if TYPE_CHECKING:
-    from backend.data.model import OAuth2Credentials
-
-logger = logging.getLogger(__name__)
-
-
-def normalize_mcp_url(url: str) -> str:
-    """Normalize an MCP server URL for consistent credential matching.
-
-    Strips leading/trailing whitespace and a single trailing slash so that
-    ``https://mcp.example.com/`` and ``https://mcp.example.com`` resolve to
-    the same stored credential.
-    """
-    return url.strip().rstrip("/")
-
-
-def server_host(server_url: str) -> str:
-    """Extract the hostname from a server URL for display purposes.
-
-    Uses ``parsed.hostname`` (never ``netloc``) to strip any embedded
-    username/password before surfacing the value in UI messages.
-    """
-    try:
-        parsed = urlparse(server_url)
-        return parsed.hostname or server_url
-    except Exception:
-        return server_url
-
-
-def parse_mcp_content(content: list[dict[str, Any]]) -> Any:
-    """Parse MCP tool response content into a plain Python value.
-
-    - text items: parsed as JSON when possible, kept as str otherwise
-    - image items: kept as ``{type, data, mimeType}`` dict for frontend rendering
-    - resource items: unwrapped to their resource payload dict
-
-    Single-item responses are unwrapped from the list; multiple items are
-    returned as a list; empty content returns ``None``.
-    """
-    output_parts: list[Any] = []
-    for item in content:
-        item_type = item.get("type")
-        if item_type == "text":
-            text = item.get("text", "")
-            try:
-                output_parts.append(json.loads(text))
-            except (json.JSONDecodeError, ValueError):
-                output_parts.append(text)
-        elif item_type == "image":
-            output_parts.append(
-                {
-                    "type": "image",
-                    "data": item.get("data"),
-                    "mimeType": item.get("mimeType"),
-                }
-            )
-        elif item_type == "resource":
-            output_parts.append(item.get("resource", {}))
-
-    if len(output_parts) == 1:
-        return output_parts[0]
-    return output_parts or None
-
-
-async def auto_lookup_mcp_credential(
-    user_id: str, server_url: str
-) -> OAuth2Credentials | None:
-    """Look up the best stored MCP credential for *server_url*.
-
-    The caller should pass a **normalized** URL (via :func:`normalize_mcp_url`)
-    so the comparison with ``mcp_server_url`` in credential metadata matches.
-
-    Returns the credential with the latest ``access_token_expires_at``, refreshed
-    if needed, or ``None`` when no match is found.
-    """
-    from backend.data.model import OAuth2Credentials
-    from backend.integrations.creds_manager import IntegrationCredentialsManager
-    from backend.integrations.providers import ProviderName
-
-    try:
-        mgr = IntegrationCredentialsManager()
-        mcp_creds = await mgr.store.get_creds_by_provider(
-            user_id, ProviderName.MCP.value
-        )
-        # Collect all matching credentials and pick the best one.
-        # Primary sort: latest access_token_expires_at (tokens with expiry
-        # are preferred over non-expiring ones). Secondary sort: last in
-        # iteration order, which corresponds to the most recently created
-        # row — this acts as a tiebreaker when multiple bearer tokens have
-        # no expiry (e.g. after a failed old-credential cleanup).
-        best: OAuth2Credentials | None = None
-        for cred in mcp_creds:
-            if (
-                isinstance(cred, OAuth2Credentials)
-                and (cred.metadata or {}).get("mcp_server_url") == server_url
-            ):
-                if best is None or (
-                    (cred.access_token_expires_at or 0)
-                    >= (best.access_token_expires_at or 0)
-                ):
-                    best = cred
-        if best:
-            best = await mgr.refresh_if_needed(user_id, best)
-            logger.info("Auto-resolved MCP credential %s for %s", best.id, server_url)
-            return best
-    except Exception:
-        logger.warning("Auto-lookup MCP credential failed", exc_info=True)
-    return None
@@ -1,98 +0,0 @@
-"""Unit tests for the shared MCP helpers."""
-
-from backend.blocks.mcp.helpers import normalize_mcp_url, parse_mcp_content, server_host
-
-# ---------------------------------------------------------------------------
-# normalize_mcp_url
-# ---------------------------------------------------------------------------
-
-
-def test_normalize_trailing_slash():
-    assert normalize_mcp_url("https://mcp.example.com/") == "https://mcp.example.com"
-
-
-def test_normalize_whitespace():
-    assert normalize_mcp_url(" https://mcp.example.com ") == "https://mcp.example.com"
-
-
-def test_normalize_both():
-    assert (
-        normalize_mcp_url(" https://mcp.example.com/ ") == "https://mcp.example.com"
-    )
-
-
-def test_normalize_noop():
-    assert normalize_mcp_url("https://mcp.example.com") == "https://mcp.example.com"
-
-
-def test_normalize_path_with_trailing_slash():
-    assert (
-        normalize_mcp_url("https://mcp.example.com/path/")
-        == "https://mcp.example.com/path"
-    )
-
-
-# ---------------------------------------------------------------------------
-# server_host
-# ---------------------------------------------------------------------------
-
-
-def test_server_host_standard_url():
-    assert server_host("https://mcp.example.com/mcp") == "mcp.example.com"
-
-
-def test_server_host_strips_credentials():
-    """hostname must not expose user:pass."""
-    assert server_host("https://user:secret@mcp.example.com/mcp") == "mcp.example.com"
-
-
-def test_server_host_with_port():
-    """Port should not appear in hostname (hostname strips it)."""
-    assert server_host("https://mcp.example.com:8080/mcp") == "mcp.example.com"
-
-
-def test_server_host_fallback():
-    """Falls back to the raw string for un-parseable URLs."""
-    assert server_host("not-a-url") == "not-a-url"
-
-
-# ---------------------------------------------------------------------------
-# parse_mcp_content
-# ---------------------------------------------------------------------------
-
-
-def test_parse_text_plain():
-    assert parse_mcp_content([{"type": "text", "text": "hello world"}]) == "hello world"
-
-
-def test_parse_text_json():
-    content = [{"type": "text", "text": '{"status": "ok", "count": 42}'}]
-    assert parse_mcp_content(content) == {"status": "ok", "count": 42}
-
-
-def test_parse_image():
-    content = [{"type": "image", "data": "abc123==", "mimeType": "image/png"}]
-    assert parse_mcp_content(content) == {
-        "type": "image",
-        "data": "abc123==",
-        "mimeType": "image/png",
-    }
-
-
-def test_parse_resource():
-    content = [
-        {"type": "resource", "resource": {"uri": "file:///tmp/out.txt", "text": "hi"}}
-    ]
-    assert parse_mcp_content(content) == {"uri": "file:///tmp/out.txt", "text": "hi"}
-
-
-def test_parse_multi_item():
-    content = [
-        {"type": "text", "text": "first"},
-        {"type": "text", "text": "second"},
-    ]
-    assert parse_mcp_content(content) == ["first", "second"]
-
-
-def test_parse_empty():
-    assert parse_mcp_content([]) is None
@@ -4,7 +4,7 @@ from enum import Enum
 from typing import Any, Literal

 import openai
-from pydantic import SecretStr, field_validator
+from pydantic import SecretStr

 from backend.blocks._base import (
     Block,
@@ -13,7 +13,6 @@ from backend.blocks._base import (
     BlockSchemaInput,
     BlockSchemaOutput,
 )
-from backend.data.block import BlockInput
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
@@ -22,7 +21,6 @@ from backend.data.model import (
     SchemaField,
 )
 from backend.integrations.providers import ProviderName
-from backend.util.clients import OPENROUTER_BASE_URL
 from backend.util.logging import TruncatedLogger

 logger = TruncatedLogger(logging.getLogger(__name__), "[Perplexity-Block]")
@@ -36,20 +34,6 @@ class PerplexityModel(str, Enum):
     SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"


-def _sanitize_perplexity_model(value: Any) -> PerplexityModel:
-    """Return a valid PerplexityModel, falling back to SONAR for invalid values."""
-    if isinstance(value, PerplexityModel):
-        return value
-    try:
-        return PerplexityModel(value)
-    except ValueError:
-        logger.warning(
-            f"Invalid PerplexityModel '{value}', "
-            f"falling back to {PerplexityModel.SONAR.value}"
-        )
-        return PerplexityModel.SONAR
-
-
 PerplexityCredentials = CredentialsMetaInput[
     Literal[ProviderName.OPEN_ROUTER], Literal["api_key"]
 ]
@@ -88,25 +72,6 @@ class PerplexityBlock(Block):
             advanced=False,
         )
         credentials: PerplexityCredentials = PerplexityCredentialsField()
-
-        @field_validator("model", mode="before")
-        @classmethod
-        def fallback_invalid_model(cls, v: Any) -> PerplexityModel:
-            """Fall back to SONAR if the model value is not a valid
-            PerplexityModel (e.g. an OpenAI model ID set by the agent
-            generator)."""
-            return _sanitize_perplexity_model(v)
-
-        @classmethod
-        def validate_data(cls, data: BlockInput) -> str | None:
-            """Sanitize the model field before JSON schema validation so that
-            invalid values are replaced with the default instead of raising a
-            BlockInputError."""
-            model_value = data.get("model")
-            if model_value is not None:
-                data["model"] = _sanitize_perplexity_model(model_value).value
-            return super().validate_data(data)
-
         system_prompt: str = SchemaField(
             title="System Prompt",
             default="",
@@ -171,7 +136,7 @@ class PerplexityBlock(Block):
     ) -> dict[str, Any]:
         """Call Perplexity via OpenRouter and extract annotations."""
         client = openai.AsyncOpenAI(
-            base_url=OPENROUTER_BASE_URL,
+            base_url="https://openrouter.ai/api/v1",
             api_key=credentials.api_key.get_secret_value(),
         )

@@ -2232,7 +2232,6 @@ class DeleteRedditPostBlock(Block):
                 ("post_id", "abc123"),
             ],
             test_mock={"delete_post": lambda creds, post_id: True},
-            is_sensitive_action=True,
         )

     @staticmethod
@@ -2291,7 +2290,6 @@ class DeleteRedditCommentBlock(Block):
                 ("comment_id", "xyz789"),
             ],
             test_mock={"delete_comment": lambda creds, comment_id: True},
-            is_sensitive_action=True,
        )

     @staticmethod
@@ -72,7 +72,6 @@ class Slant3DCreateOrderBlock(Slant3DBlockBase):
                 "_make_request": lambda *args, **kwargs: {"orderId": "314144241"},
                 "_convert_to_color": lambda *args, **kwargs: "black",
             },
-            is_sensitive_action=True,
         )

     async def run(
@@ -83,8 +83,7 @@ class StagehandRecommendedLlmModel(str, Enum):
     GPT41_MINI = "gpt-4.1-mini-2025-04-14"

     # Anthropic
-    CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"  # Keep for backwards compat
-    CLAUDE_4_6_SONNET = "claude-sonnet-4-6"
+    CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"

     @property
     def provider_name(self) -> str:
@@ -138,7 +137,7 @@ class StagehandObserveBlock(Block):
        model: StagehandRecommendedLlmModel = SchemaField(
            title="LLM Model",
            description="LLM to use for Stagehand (provider is inferred)",
-           default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
+           default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
            advanced=False,
        )
        model_credentials: AICredentials = AICredentialsField()
@@ -228,7 +227,7 @@ class StagehandActBlock(Block):
        model: StagehandRecommendedLlmModel = SchemaField(
            title="LLM Model",
            description="LLM to use for Stagehand (provider is inferred)",
-           default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
+           default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
            advanced=False,
        )
        model_credentials: AICredentials = AICredentialsField()
@@ -325,7 +324,7 @@ class StagehandExtractBlock(Block):
        model: StagehandRecommendedLlmModel = SchemaField(
            title="LLM Model",
            description="LLM to use for Stagehand (provider is inferred)",
-           default=StagehandRecommendedLlmModel.CLAUDE_4_6_SONNET,
+           default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
            advanced=False,
        )
        model_credentials: AICredentials = AICredentialsField()
@@ -1,8 +1,8 @@
 import logging
+from typing import Literal

 from pydantic import BaseModel

-from backend.api.features.store.db import StoreAgentsSortOptions
 from backend.blocks._base import (
     Block,
     BlockCategory,
@@ -176,8 +176,8 @@ class SearchStoreAgentsBlock(Block):
         category: str | None = SchemaField(
             description="Filter by category", default=None
         )
-        sort_by: StoreAgentsSortOptions = SchemaField(
-            description="How to sort the results", default=StoreAgentsSortOptions.RATING
+        sort_by: Literal["rating", "runs", "name", "updated_at"] = SchemaField(
+            description="How to sort the results", default="rating"
         )
         limit: int = SchemaField(
             description="Maximum number of results to return", default=10, ge=1, le=100
@@ -278,7 +278,7 @@ class SearchStoreAgentsBlock(Block):
|
|||||||
self,
|
self,
|
||||||
query: str | None = None,
|
query: str | None = None,
|
||||||
category: str | None = None,
|
category: str | None = None,
|
||||||
sort_by: StoreAgentsSortOptions = StoreAgentsSortOptions.RATING,
|
sort_by: Literal["rating", "runs", "name", "updated_at"] = "rating",
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
) -> SearchAgentsResponse:
|
) -> SearchAgentsResponse:
|
||||||
"""
|
"""
|
||||||
|
|||||||
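The hunks above swap the `StoreAgentsSortOptions` enum for a plain string `Literal` in the block's public schema. A minimal standalone sketch of that validation pattern — the `SortBy` alias and `normalize_sort_by` helper here are illustrative, not names from the repo:

```python
from typing import Literal, get_args

# Mirrors the Literal type introduced in the diff above.
SortBy = Literal["rating", "runs", "name", "updated_at"]


def normalize_sort_by(value: str) -> str:
    """Validate a sort key against the allowed literal values."""
    allowed = get_args(SortBy)
    if value not in allowed:
        raise ValueError(f"sort_by must be one of {allowed}, got {value!r}")
    return value
```

Using string literals instead of an enum keeps the JSON schema for the block simple (an `enum` of plain strings) and avoids importing database-layer types into the block's interface.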
```diff
@@ -1,182 +0,0 @@
-"""
-Telegram Bot API helper functions.
-
-Provides utilities for making authenticated requests to the Telegram Bot API.
-"""
-
-import logging
-from io import BytesIO
-from typing import Any, Optional
-
-from pydantic import BaseModel
-
-from backend.data.model import APIKeyCredentials
-from backend.util.request import Requests
-
-logger = logging.getLogger(__name__)
-
-TELEGRAM_API_BASE = "https://api.telegram.org"
-
-
-class TelegramMessageResult(BaseModel, extra="allow"):
-    """Result from Telegram send/edit message API calls."""
-
-    message_id: int = 0
-    chat: dict[str, Any] = {}
-    date: int = 0
-    text: str = ""
-
-
-class TelegramFileResult(BaseModel, extra="allow"):
-    """Result from Telegram getFile API call."""
-
-    file_id: str = ""
-    file_unique_id: str = ""
-    file_size: int = 0
-    file_path: str = ""
-
-
-class TelegramAPIException(ValueError):
-    """Exception raised for Telegram API errors."""
-
-    def __init__(self, message: str, error_code: int = 0):
-        super().__init__(message)
-        self.error_code = error_code
-
-
-def get_bot_api_url(bot_token: str, method: str) -> str:
-    """Construct Telegram Bot API URL for a method."""
-    return f"{TELEGRAM_API_BASE}/bot{bot_token}/{method}"
-
-
-def get_file_url(bot_token: str, file_path: str) -> str:
-    """Construct Telegram file download URL."""
-    return f"{TELEGRAM_API_BASE}/file/bot{bot_token}/{file_path}"
-
-
-async def call_telegram_api(
-    credentials: APIKeyCredentials,
-    method: str,
-    data: Optional[dict[str, Any]] = None,
-) -> TelegramMessageResult:
-    """
-    Make a request to the Telegram Bot API.
-
-    Args:
-        credentials: Bot token credentials
-        method: API method name (e.g., "sendMessage", "getFile")
-        data: Request parameters
-
-    Returns:
-        API response result
-
-    Raises:
-        TelegramAPIException: If the API returns an error
-    """
-    token = credentials.api_key.get_secret_value()
-    url = get_bot_api_url(token, method)
-
-    response = await Requests().post(url, json=data or {})
-    result = response.json()
-
-    if not result.get("ok"):
-        error_code = result.get("error_code", 0)
-        description = result.get("description", "Unknown error")
-        raise TelegramAPIException(description, error_code)
-
-    return TelegramMessageResult(**result.get("result", {}))
-
-
-async def call_telegram_api_with_file(
-    credentials: APIKeyCredentials,
-    method: str,
-    file_field: str,
-    file_data: bytes,
-    filename: str,
-    content_type: str,
-    data: Optional[dict[str, Any]] = None,
-) -> TelegramMessageResult:
-    """
-    Make a multipart/form-data request to the Telegram Bot API with a file upload.
-
-    Args:
-        credentials: Bot token credentials
-        method: API method name (e.g., "sendPhoto", "sendVoice")
-        file_field: Form field name for the file (e.g., "photo", "voice")
-        file_data: Raw file bytes
-        filename: Filename for the upload
-        content_type: MIME type of the file
-        data: Additional form parameters
-
-    Returns:
-        API response result
-
-    Raises:
-        TelegramAPIException: If the API returns an error
-    """
-    token = credentials.api_key.get_secret_value()
-    url = get_bot_api_url(token, method)
-
-    files = [(file_field, (filename, BytesIO(file_data), content_type))]
-
-    response = await Requests().post(url, files=files, data=data or {})
-    result = response.json()
-
-    if not result.get("ok"):
-        error_code = result.get("error_code", 0)
-        description = result.get("description", "Unknown error")
-        raise TelegramAPIException(description, error_code)
-
-    return TelegramMessageResult(**result.get("result", {}))
-
-
-async def get_file_info(
-    credentials: APIKeyCredentials, file_id: str
-) -> TelegramFileResult:
-    """
-    Get file information from Telegram.
-
-    Args:
-        credentials: Bot token credentials
-        file_id: Telegram file_id from message
-
-    Returns:
-        File info dict containing file_id, file_unique_id, file_size, file_path
-    """
-    result = await call_telegram_api(credentials, "getFile", {"file_id": file_id})
-    return TelegramFileResult(**result.model_dump())
-
-
-async def get_file_download_url(credentials: APIKeyCredentials, file_id: str) -> str:
-    """
-    Get the download URL for a Telegram file.
-
-    Args:
-        credentials: Bot token credentials
-        file_id: Telegram file_id from message
-
-    Returns:
-        Full download URL
-    """
-    token = credentials.api_key.get_secret_value()
-    result = await get_file_info(credentials, file_id)
-    file_path = result.file_path
-    if not file_path:
-        raise TelegramAPIException("No file_path returned from getFile")
-    return get_file_url(token, file_path)
-
-
-async def download_telegram_file(credentials: APIKeyCredentials, file_id: str) -> bytes:
-    """
-    Download a file from Telegram servers.
-
-    Args:
-        credentials: Bot token credentials
-        file_id: Telegram file_id
-
-    Returns:
-        File content as bytes
-    """
-    url = await get_file_download_url(credentials, file_id)
-    response = await Requests().get(url)
-    return response.content
```
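The deleted helper's core contract — Telegram wraps every response in an `{"ok": ..., "result": ...}` envelope and signals failures via `error_code`/`description` — can be sketched independently of the repo's `Requests` wrapper. The `TelegramAPIError` and `unwrap_telegram_response` names below are illustrative, not the repo's:

```python
from typing import Any


class TelegramAPIError(ValueError):
    """Raised when the Telegram Bot API returns ok=False."""

    def __init__(self, message: str, error_code: int = 0):
        super().__init__(message)
        self.error_code = error_code


def unwrap_telegram_response(payload: dict[str, Any]) -> dict[str, Any]:
    """Return the 'result' object, or raise on an API-level error."""
    if not payload.get("ok"):
        raise TelegramAPIError(
            payload.get("description", "Unknown error"),
            payload.get("error_code", 0),
        )
    return payload.get("result", {})
```

Keeping the envelope-unwrapping in one place means every API call site gets consistent error translation (HTTP 200 with `ok=False` is still a failure in the Bot API).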
```diff
@@ -1,43 +0,0 @@
-"""
-Telegram Bot credentials handling.
-
-Telegram bots use an API key (bot token) obtained from @BotFather.
-"""
-
-from typing import Literal
-
-from pydantic import SecretStr
-
-from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
-from backend.integrations.providers import ProviderName
-
-# Bot token credentials (API key style)
-TelegramCredentials = APIKeyCredentials
-TelegramCredentialsInput = CredentialsMetaInput[
-    Literal[ProviderName.TELEGRAM], Literal["api_key"]
-]
-
-
-def TelegramCredentialsField() -> TelegramCredentialsInput:
-    """Creates a Telegram bot token credentials field."""
-    return CredentialsField(
-        description="Telegram Bot API token from @BotFather. "
-        "Create a bot at https://t.me/BotFather to get your token."
-    )
-
-
-# Test credentials for unit tests
-TEST_CREDENTIALS = APIKeyCredentials(
-    id="01234567-89ab-cdef-0123-456789abcdef",
-    provider="telegram",
-    api_key=SecretStr("test_telegram_bot_token"),
-    title="Mock Telegram Bot Token",
-    expires_at=None,
-)
-
-TEST_CREDENTIALS_INPUT = {
-    "provider": TEST_CREDENTIALS.provider,
-    "id": TEST_CREDENTIALS.id,
-    "type": TEST_CREDENTIALS.type,
-    "title": TEST_CREDENTIALS.title,
-}
```
File diff suppressed because it is too large
```diff
@@ -1,377 +0,0 @@
-"""
-Telegram trigger blocks for receiving messages via webhooks.
-"""
-
-import logging
-
-from pydantic import BaseModel
-
-from backend.blocks._base import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-    BlockWebhookConfig,
-)
-from backend.data.model import SchemaField
-from backend.integrations.providers import ProviderName
-from backend.integrations.webhooks.telegram import TelegramWebhookType
-
-from ._auth import (
-    TEST_CREDENTIALS,
-    TEST_CREDENTIALS_INPUT,
-    TelegramCredentialsField,
-    TelegramCredentialsInput,
-)
-
-logger = logging.getLogger(__name__)
-
-
-# Example payload for testing
-EXAMPLE_MESSAGE_PAYLOAD = {
-    "update_id": 123456789,
-    "message": {
-        "message_id": 1,
-        "from": {
-            "id": 12345678,
-            "is_bot": False,
-            "first_name": "John",
-            "last_name": "Doe",
-            "username": "johndoe",
-            "language_code": "en",
-        },
-        "chat": {
-            "id": 12345678,
-            "first_name": "John",
-            "last_name": "Doe",
-            "username": "johndoe",
-            "type": "private",
-        },
-        "date": 1234567890,
-        "text": "Hello, bot!",
-    },
-}
-
-
-class TelegramTriggerBase:
-    """Base class for Telegram trigger blocks."""
-
-    class Input(BlockSchemaInput):
-        credentials: TelegramCredentialsInput = TelegramCredentialsField()
-        payload: dict = SchemaField(hidden=True, default_factory=dict)
-
-
-class TelegramMessageTriggerBlock(TelegramTriggerBase, Block):
-    """
-    Triggers when a message is received or edited in your Telegram bot.
-
-    Supports text, photos, voice messages, audio files, documents, and videos.
-    Connect the outputs to other blocks to process messages and send responses.
-    """
-
-    class Input(TelegramTriggerBase.Input):
-        class EventsFilter(BaseModel):
-            """Filter for message types to receive."""
-
-            text: bool = True
-            photo: bool = False
-            voice: bool = False
-            audio: bool = False
-            document: bool = False
-            video: bool = False
-            edited_message: bool = False
-
-        events: EventsFilter = SchemaField(
-            title="Message Types", description="Types of messages to receive"
-        )
-
-    class Output(BlockSchemaOutput):
-        payload: dict = SchemaField(
-            description="The complete webhook payload from Telegram"
-        )
-        chat_id: int = SchemaField(
-            description="The chat ID where the message was received. "
-            "Use this to send replies."
-        )
-        message_id: int = SchemaField(description="The unique message ID")
-        user_id: int = SchemaField(description="The user ID who sent the message")
-        username: str = SchemaField(description="Username of the sender (may be empty)")
-        first_name: str = SchemaField(description="First name of the sender")
-        event: str = SchemaField(
-            description="The message type (text, photo, voice, audio, etc.)"
-        )
-        text: str = SchemaField(
-            description="Text content of the message (for text messages)"
-        )
-        photo_file_id: str = SchemaField(
-            description="File ID of the photo (for photo messages). "
-            "Use GetTelegramFileBlock to download."
-        )
-        voice_file_id: str = SchemaField(
-            description="File ID of the voice message (for voice messages). "
-            "Use GetTelegramFileBlock to download."
-        )
-        audio_file_id: str = SchemaField(
-            description="File ID of the audio file (for audio messages). "
-            "Use GetTelegramFileBlock to download."
-        )
-        file_id: str = SchemaField(
-            description="File ID for document/video messages. "
-            "Use GetTelegramFileBlock to download."
-        )
-        file_name: str = SchemaField(
-            description="Original filename (for document/audio messages)"
-        )
-        caption: str = SchemaField(description="Caption for media messages")
-        is_edited: bool = SchemaField(
-            description="Whether this is an edit of a previously sent message"
-        )
-
-    def __init__(self):
-        super().__init__(
-            id="4435e4e0-df6e-4301-8f35-ad70b12fc9ec",
-            description="Triggers when a message is received or edited in your Telegram bot. "
-            "Supports text, photos, voice messages, audio files, documents, and videos.",
-            categories={BlockCategory.SOCIAL},
-            input_schema=TelegramMessageTriggerBlock.Input,
-            output_schema=TelegramMessageTriggerBlock.Output,
-            webhook_config=BlockWebhookConfig(
-                provider=ProviderName.TELEGRAM,
-                webhook_type=TelegramWebhookType.BOT,
-                resource_format="bot",
-                event_filter_input="events",
-                event_format="message.{event}",
-            ),
-            test_input={
-                "events": {"text": True, "photo": True},
-                "credentials": TEST_CREDENTIALS_INPUT,
-                "payload": EXAMPLE_MESSAGE_PAYLOAD,
-            },
-            test_credentials=TEST_CREDENTIALS,
-            test_output=[
-                ("payload", EXAMPLE_MESSAGE_PAYLOAD),
-                ("chat_id", 12345678),
-                ("message_id", 1),
-                ("user_id", 12345678),
-                ("username", "johndoe"),
-                ("first_name", "John"),
-                ("is_edited", False),
-                ("event", "text"),
-                ("text", "Hello, bot!"),
-                ("photo_file_id", ""),
-                ("voice_file_id", ""),
-                ("audio_file_id", ""),
-                ("file_id", ""),
-                ("file_name", ""),
-                ("caption", ""),
-            ],
-        )
-
-    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        payload = input_data.payload
-        is_edited = "edited_message" in payload
-        message = payload.get("message") or payload.get("edited_message", {})
-
-        # Extract common fields
-        chat = message.get("chat", {})
-        sender = message.get("from", {})
-
-        yield "payload", payload
-        yield "chat_id", chat.get("id", 0)
-        yield "message_id", message.get("message_id", 0)
-        yield "user_id", sender.get("id", 0)
-        yield "username", sender.get("username", "")
-        yield "first_name", sender.get("first_name", "")
-        yield "is_edited", is_edited
-
-        # For edited messages, yield event as "edited_message" and extract
-        # all content fields from the edited message body
-        if is_edited:
-            yield "event", "edited_message"
-            yield "text", message.get("text", "")
-            photos = message.get("photo", [])
-            yield "photo_file_id", photos[-1].get("file_id", "") if photos else ""
-            voice = message.get("voice", {})
-            yield "voice_file_id", voice.get("file_id", "")
-            audio = message.get("audio", {})
-            yield "audio_file_id", audio.get("file_id", "")
-            document = message.get("document", {})
-            video = message.get("video", {})
-            yield "file_id", (document.get("file_id", "") or video.get("file_id", ""))
-            yield "file_name", (
-                document.get("file_name", "") or audio.get("file_name", "")
-            )
-            yield "caption", message.get("caption", "")
-        # Determine message type and extract content
-        elif "text" in message:
-            yield "event", "text"
-            yield "text", message.get("text", "")
-            yield "photo_file_id", ""
-            yield "voice_file_id", ""
-            yield "audio_file_id", ""
-            yield "file_id", ""
-            yield "file_name", ""
-            yield "caption", ""
-        elif "photo" in message:
-            # Get the largest photo (last in array)
-            photos = message.get("photo", [])
-            photo_fid = photos[-1].get("file_id", "") if photos else ""
-            yield "event", "photo"
-            yield "text", ""
-            yield "photo_file_id", photo_fid
-            yield "voice_file_id", ""
-            yield "audio_file_id", ""
-            yield "file_id", ""
-            yield "file_name", ""
-            yield "caption", message.get("caption", "")
-        elif "voice" in message:
-            voice = message.get("voice", {})
-            yield "event", "voice"
-            yield "text", ""
-            yield "photo_file_id", ""
-            yield "voice_file_id", voice.get("file_id", "")
-            yield "audio_file_id", ""
-            yield "file_id", ""
-            yield "file_name", ""
-            yield "caption", message.get("caption", "")
-        elif "audio" in message:
-            audio = message.get("audio", {})
-            yield "event", "audio"
-            yield "text", ""
-            yield "photo_file_id", ""
-            yield "voice_file_id", ""
-            yield "audio_file_id", audio.get("file_id", "")
-            yield "file_id", ""
-            yield "file_name", audio.get("file_name", "")
-            yield "caption", message.get("caption", "")
-        elif "document" in message:
-            document = message.get("document", {})
-            yield "event", "document"
-            yield "text", ""
-            yield "photo_file_id", ""
-            yield "voice_file_id", ""
-            yield "audio_file_id", ""
-            yield "file_id", document.get("file_id", "")
-            yield "file_name", document.get("file_name", "")
-            yield "caption", message.get("caption", "")
-        elif "video" in message:
-            video = message.get("video", {})
-            yield "event", "video"
-            yield "text", ""
-            yield "photo_file_id", ""
-            yield "voice_file_id", ""
-            yield "audio_file_id", ""
-            yield "file_id", video.get("file_id", "")
-            yield "file_name", video.get("file_name", "")
-            yield "caption", message.get("caption", "")
-        else:
-            yield "event", "other"
-            yield "text", ""
-            yield "photo_file_id", ""
-            yield "voice_file_id", ""
-            yield "audio_file_id", ""
-            yield "file_id", ""
-            yield "file_name", ""
-            yield "caption", ""
-
-
-# Example payload for reaction trigger testing
-EXAMPLE_REACTION_PAYLOAD = {
-    "update_id": 123456790,
-    "message_reaction": {
-        "chat": {
-            "id": 12345678,
-            "first_name": "John",
-            "last_name": "Doe",
-            "username": "johndoe",
-            "type": "private",
-        },
-        "message_id": 42,
-        "user": {
-            "id": 12345678,
-            "is_bot": False,
-            "first_name": "John",
-            "username": "johndoe",
-        },
-        "date": 1234567890,
-        "new_reaction": [{"type": "emoji", "emoji": "👍"}],
-        "old_reaction": [],
-    },
-}
-
-
-class TelegramMessageReactionTriggerBlock(TelegramTriggerBase, Block):
-    """
-    Triggers when a reaction to a message is changed.
-
-    Works automatically in private chats. In group chats, the bot must be
-    an administrator to receive reaction updates.
-    """
-
-    class Input(TelegramTriggerBase.Input):
-        pass
-
-    class Output(BlockSchemaOutput):
-        payload: dict = SchemaField(
-            description="The complete webhook payload from Telegram"
-        )
-        chat_id: int = SchemaField(
-            description="The chat ID where the reaction occurred"
-        )
-        message_id: int = SchemaField(description="The message ID that was reacted to")
-        user_id: int = SchemaField(description="The user ID who changed the reaction")
-        username: str = SchemaField(description="Username of the user (may be empty)")
-        new_reactions: list = SchemaField(
-            description="List of new reactions on the message"
-        )
-        old_reactions: list = SchemaField(
-            description="List of previous reactions on the message"
-        )
-
-    def __init__(self):
-        super().__init__(
-            id="82525328-9368-4966-8f0c-cd78e80181fd",
-            description="Triggers when a reaction to a message is changed. "
-            "Works in private chats automatically. "
-            "In groups, the bot must be an administrator.",
-            categories={BlockCategory.SOCIAL},
-            input_schema=TelegramMessageReactionTriggerBlock.Input,
-            output_schema=TelegramMessageReactionTriggerBlock.Output,
-            webhook_config=BlockWebhookConfig(
-                provider=ProviderName.TELEGRAM,
-                webhook_type=TelegramWebhookType.BOT,
-                resource_format="bot",
-                event_filter_input="",
-                event_format="message_reaction",
-            ),
-            test_input={
-                "credentials": TEST_CREDENTIALS_INPUT,
-                "payload": EXAMPLE_REACTION_PAYLOAD,
-            },
-            test_credentials=TEST_CREDENTIALS,
-            test_output=[
-                ("payload", EXAMPLE_REACTION_PAYLOAD),
-                ("chat_id", 12345678),
-                ("message_id", 42),
-                ("user_id", 12345678),
-                ("username", "johndoe"),
-                ("new_reactions", [{"type": "emoji", "emoji": "👍"}]),
-                ("old_reactions", []),
-            ],
-        )
-
-    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        payload = input_data.payload
-        reaction = payload.get("message_reaction", {})
-
-        chat = reaction.get("chat", {})
-        user = reaction.get("user", {})
-
-        yield "payload", payload
-        yield "chat_id", chat.get("id", 0)
-        yield "message_id", reaction.get("message_id", 0)
-        yield "user_id", user.get("id", 0)
-        yield "username", user.get("username", "")
-        yield "new_reactions", reaction.get("new_reaction", [])
-        yield "old_reactions", reaction.get("old_reaction", [])
```
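The trigger's `run` method above is essentially a dispatch on which content key the Telegram message object carries. The core of that dispatch as a standalone sketch (`classify_message` is an illustrative name, not a function in the repo):

```python
def classify_message(message: dict) -> str:
    """Return the trigger event name for a Telegram message object."""
    # Check content keys in the same priority order as the block's run method.
    for key in ("text", "photo", "voice", "audio", "document", "video"):
        if key in message:
            return key
    return "other"
```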
```diff
@@ -1,81 +0,0 @@
-"""Unit tests for PerplexityBlock model fallback behavior."""
-
-import pytest
-
-from backend.blocks.perplexity import (
-    TEST_CREDENTIALS_INPUT,
-    PerplexityBlock,
-    PerplexityModel,
-)
-
-
-def _make_input(**overrides) -> dict:
-    defaults = {
-        "prompt": "test query",
-        "credentials": TEST_CREDENTIALS_INPUT,
-    }
-    defaults.update(overrides)
-    return defaults
-
-
-class TestPerplexityModelFallback:
-    """Tests for fallback_invalid_model field_validator."""
-
-    def test_invalid_model_falls_back_to_sonar(self):
-        inp = PerplexityBlock.Input(**_make_input(model="gpt-5.2-2025-12-11"))
-        assert inp.model == PerplexityModel.SONAR
-
-    def test_another_invalid_model_falls_back_to_sonar(self):
-        inp = PerplexityBlock.Input(**_make_input(model="gpt-4o"))
-        assert inp.model == PerplexityModel.SONAR
-
-    def test_valid_model_string_is_kept(self):
-        inp = PerplexityBlock.Input(**_make_input(model="perplexity/sonar-pro"))
-        assert inp.model == PerplexityModel.SONAR_PRO
-
-    def test_valid_enum_value_is_kept(self):
-        inp = PerplexityBlock.Input(
-            **_make_input(model=PerplexityModel.SONAR_DEEP_RESEARCH)
-        )
-        assert inp.model == PerplexityModel.SONAR_DEEP_RESEARCH
-
-    def test_default_model_when_omitted(self):
-        inp = PerplexityBlock.Input(**_make_input())
-        assert inp.model == PerplexityModel.SONAR
-
-    @pytest.mark.parametrize(
-        "model_value",
-        [
-            "perplexity/sonar",
-            "perplexity/sonar-pro",
-            "perplexity/sonar-deep-research",
-        ],
-    )
-    def test_all_valid_models_accepted(self, model_value: str):
-        inp = PerplexityBlock.Input(**_make_input(model=model_value))
-        assert inp.model.value == model_value
-
-
-class TestPerplexityValidateData:
-    """Tests for validate_data which runs during block execution (before
-    Pydantic instantiation). Invalid models must be sanitized here so
-    JSON schema validation does not reject them."""
-
-    def test_invalid_model_sanitized_before_schema_validation(self):
-        data = _make_input(model="gpt-5.2-2025-12-11")
-        error = PerplexityBlock.Input.validate_data(data)
-        assert error is None
-        assert data["model"] == PerplexityModel.SONAR.value
-
-    def test_valid_model_unchanged_by_validate_data(self):
-        data = _make_input(model="perplexity/sonar-pro")
-        error = PerplexityBlock.Input.validate_data(data)
-        assert error is None
-        assert data["model"] == "perplexity/sonar-pro"
-
-    def test_missing_model_uses_default(self):
-        data = _make_input()  # no model key
-        error = PerplexityBlock.Input.validate_data(data)
-        assert error is None
-        inp = PerplexityBlock.Input(**data)
-        assert inp.model == PerplexityModel.SONAR
```
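The deleted tests exercise a fallback validator that coerces unknown model names to a safe default. The underlying pattern, sketched with a standalone enum — the real validator lives on `PerplexityBlock.Input` in the repo, and the enum values here are only the ones the tests name:

```python
from enum import Enum


class PerplexityModel(str, Enum):
    # Values mirror the identifiers exercised by the deleted tests.
    SONAR = "perplexity/sonar"
    SONAR_PRO = "perplexity/sonar-pro"
    SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"


def fallback_invalid_model(value: str) -> PerplexityModel:
    """Coerce unknown model identifiers to the SONAR default."""
    try:
        return PerplexityModel(value)
    except ValueError:
        return PerplexityModel.SONAR
```

Sanitizing before schema validation (as the second test class checks) matters because a strict JSON-schema `enum` check would otherwise reject saved graphs that reference retired model names.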
```diff
@@ -2,7 +2,6 @@ from unittest.mock import MagicMock
 
 import pytest
 
-from backend.api.features.store.db import StoreAgentsSortOptions
 from backend.blocks.system.library_operations import (
     AddToLibraryFromStoreBlock,
     LibraryAgent,
@@ -122,10 +121,7 @@ async def test_search_store_agents_block(mocker):
     )
 
     input_data = block.Input(
-        query="test",
-        category="productivity",
-        sort_by=StoreAgentsSortOptions.RATING,  # type: ignore[reportArgumentType]
-        limit=10,
+        query="test", category="productivity", sort_by="rating", limit=10
     )
 
     outputs = {}
```
```diff
@@ -34,12 +34,10 @@ def main(output: Path, pretty: bool):
     """Generate and output the OpenAPI JSON specification."""
     openapi_schema = get_openapi_schema()
 
-    json_output = json.dumps(
-        openapi_schema, indent=2 if pretty else None, ensure_ascii=False
-    )
+    json_output = json.dumps(openapi_schema, indent=2 if pretty else None)
 
     if output:
-        output.write_text(json_output, encoding="utf-8")
+        output.write_text(json_output)
         click.echo(f"✅ OpenAPI specification written to {output}\n\nPreview:")
         click.echo(f"\n{json_output[:500]} ...")
     else:
```
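The hunk above drops `ensure_ascii=False` (and the explicit `encoding="utf-8"` on `write_text`), so the emitted spec falls back to `json.dumps`'s default of escaping non-ASCII characters into `\uXXXX` sequences. The difference in a nutshell:

```python
import json

# Default: non-ASCII characters are escaped to \uXXXX sequences.
escaped = json.dumps({"status": "✅"})
# With ensure_ascii=False, the character is written literally
# (which then requires an explicit encoding when writing to disk).
literal = json.dumps({"status": "✅"}, ensure_ascii=False)

print(escaped)  # {"status": "\u2705"}
print(literal)  # {"status": "✅"}
```

Both forms are equivalent JSON, but the escaped form is pure ASCII and safe to write with any platform default encoding.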
```diff
@@ -1,3 +0,0 @@
-from .service import stream_chat_completion_baseline
-
-__all__ = ["stream_chat_completion_baseline"]
```
Some files were not shown because too many files have changed in this diff