mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-12 16:48:06 -05:00
Compare commits
104 Commits
autogpt-rs
...
update-ins
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
82bddd885b | ||
|
|
c6247f265e | ||
|
|
38610d1e7a | ||
|
|
ebfbf31c73 | ||
|
|
4abe37396c | ||
|
|
fa14bf461b | ||
|
|
e2c33e3d2a | ||
|
|
650be0d1f7 | ||
|
|
35bd7f7f7a | ||
|
|
312cb0227f | ||
|
|
a8feb3c8d0 | ||
|
|
5da5c2ecd6 | ||
|
|
ba65fee862 | ||
|
|
908dcd7b4b | ||
|
|
542f951dd8 | ||
|
|
72938590f2 | ||
|
|
5d364e13f6 | ||
|
|
32513b26ab | ||
|
|
bf92e7dbc8 | ||
|
|
6fce3a09ea | ||
|
|
9158d4b6a2 | ||
|
|
2403931c2e | ||
|
|
af58b316a2 | ||
|
|
03e3e2ea9a | ||
|
|
6bb6a081a2 | ||
|
|
df20b70f44 | ||
|
|
21faf1b677 | ||
|
|
b53c373a59 | ||
|
|
4bfeddc03d | ||
|
|
af7d56612d | ||
|
|
0dd30e275c | ||
|
|
c71406af8b | ||
|
|
a135f09336 | ||
|
|
2d436caa84 | ||
|
|
34dd218a91 | ||
|
|
41f500790f | ||
|
|
793de77e76 | ||
|
|
a2059c6023 | ||
|
|
b9c3920227 | ||
|
|
abba10b649 | ||
|
|
6c34790b42 | ||
|
|
c168277b1d | ||
|
|
89eb5d1189 | ||
|
|
e13e0d4376 | ||
|
|
f4a732373b | ||
|
|
28d85ad61c | ||
|
|
d4b5508ed1 | ||
|
|
0116866199 | ||
|
|
b68e490868 | ||
|
|
c1c5571fd5 | ||
|
|
da16397882 | ||
|
|
098c12a961 | ||
|
|
a28b2cf04f | ||
|
|
468d1af802 | ||
|
|
de7b6b503f | ||
|
|
5338ab5b80 | ||
|
|
e8f897ead1 | ||
|
|
fbe432919d | ||
|
|
4f208d262e | ||
|
|
ac9265c40d | ||
|
|
e60deba05f | ||
|
|
3131e2e856 | ||
|
|
378d256b58 | ||
|
|
3c52b75278 | ||
|
|
40601f1616 | ||
|
|
178c91d6b9 | ||
|
|
c972f34713 | ||
|
|
7b3ee66247 | ||
|
|
2d10ac92b5 | ||
|
|
377b5ef01c | ||
|
|
a2c88c7786 | ||
|
|
7922e4add4 | ||
|
|
e79b7a95dc | ||
|
|
f172b314a4 | ||
|
|
a21711a7ff | ||
|
|
e2af2f454d | ||
|
|
59cc3266e0 | ||
|
|
c9360555b2 | ||
|
|
4a63fbc006 | ||
|
|
9848266474 | ||
|
|
3fe88b6106 | ||
|
|
fa2d968458 | ||
|
|
b935638240 | ||
|
|
f9b255fb7a | ||
|
|
cc6697e46d | ||
|
|
5a5f7f0f9e | ||
|
|
05d4d21d98 | ||
|
|
1e6bd8d2a6 | ||
|
|
6f8d0bfdf2 | ||
|
|
b85e8204df | ||
|
|
e5d3ebac08 | ||
|
|
2182c7ba9e | ||
|
|
e043e4989b | ||
|
|
5dbc3a7d39 | ||
|
|
0978f406bc | ||
|
|
1c3fa804d4 | ||
|
|
69d873debc | ||
|
|
4283798dc2 | ||
|
|
326c4a9e0c | ||
|
|
7705cf243c | ||
|
|
8331dabf6a | ||
|
|
e632549175 | ||
|
|
878f61aaf4 | ||
|
|
9ce857e29f |
@@ -15,6 +15,7 @@
|
||||
!autogpt_platform/backend/pyproject.toml
|
||||
!autogpt_platform/backend/poetry.lock
|
||||
!autogpt_platform/backend/README.md
|
||||
!autogpt_platform/backend/.env
|
||||
|
||||
# Platform - Market
|
||||
!autogpt_platform/market/market/
|
||||
@@ -27,6 +28,7 @@
|
||||
# Platform - Frontend
|
||||
!autogpt_platform/frontend/src/
|
||||
!autogpt_platform/frontend/public/
|
||||
!autogpt_platform/frontend/scripts/
|
||||
!autogpt_platform/frontend/package.json
|
||||
!autogpt_platform/frontend/pnpm-lock.yaml
|
||||
!autogpt_platform/frontend/tsconfig.json
|
||||
@@ -34,6 +36,7 @@
|
||||
## config
|
||||
!autogpt_platform/frontend/*.config.*
|
||||
!autogpt_platform/frontend/.env.*
|
||||
!autogpt_platform/frontend/.env
|
||||
|
||||
# Classic - AutoGPT
|
||||
!classic/original_autogpt/autogpt/
|
||||
|
||||
3
.github/PULL_REQUEST_TEMPLATE.md
vendored
3
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -24,7 +24,8 @@
|
||||
</details>
|
||||
|
||||
#### For configuration changes:
|
||||
- [ ] `.env.example` is updated or already compatible with my changes
|
||||
|
||||
- [ ] `.env.default` is updated or already compatible with my changes
|
||||
- [ ] `docker-compose.yml` is updated or already compatible with my changes
|
||||
- [ ] I have included a list of my configuration changes in the PR description (under **Changes**)
|
||||
|
||||
|
||||
46
.github/workflows/platform-frontend-ci.yml
vendored
46
.github/workflows/platform-frontend-ci.yml
vendored
@@ -82,37 +82,6 @@ jobs:
|
||||
- name: Run lint
|
||||
run: pnpm lint
|
||||
|
||||
type-check:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Run tsc check
|
||||
run: pnpm type-check
|
||||
|
||||
chromatic:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
@@ -176,11 +145,7 @@ jobs:
|
||||
|
||||
- name: Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.example ../.env
|
||||
|
||||
- name: Copy backend .env
|
||||
run: |
|
||||
cp ../backend/.env.example ../backend/.env
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -252,15 +217,6 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Setup .env
|
||||
run: cp .env.example .env
|
||||
|
||||
- name: Build frontend
|
||||
run: pnpm build --turbo
|
||||
# uses Turbopack, much faster and safe enough for a test pipeline
|
||||
env:
|
||||
NEXT_PUBLIC_PW_TEST: true
|
||||
|
||||
- name: Install Browser 'chromium'
|
||||
run: pnpm playwright install --with-deps chromium
|
||||
|
||||
|
||||
132
.github/workflows/platform-fullstack-ci.yml
vendored
Normal file
132
.github/workflows/platform-fullstack-ci.yml
vendored
Normal file
@@ -0,0 +1,132 @@
|
||||
name: AutoGPT Platform - Frontend CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev]
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- "autogpt_platform/**"
|
||||
pull_request:
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- "autogpt_platform/**"
|
||||
merge_group:
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: autogpt_platform/frontend
|
||||
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Generate cache key
|
||||
id: cache-key
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
types:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Copy backend .env
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Setup .env
|
||||
run: cp .env.default .env
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
||||
|
||||
- name: Generate API queries
|
||||
run: pnpm generate:api:force
|
||||
|
||||
- name: Check for API schema changes
|
||||
run: |
|
||||
if ! git diff --exit-code src/app/api/openapi.json; then
|
||||
echo "❌ API schema changes detected in src/app/api/openapi.json"
|
||||
echo ""
|
||||
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
|
||||
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
|
||||
echo "The API schema is now out of sync with the Front-end queries."
|
||||
echo ""
|
||||
echo "To fix this:"
|
||||
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
|
||||
echo "2. Run 'pnpm generate:api' locally"
|
||||
echo "3. Run 'pnpm types' locally"
|
||||
echo "4. Fix any TypeScript errors that may have been introduced"
|
||||
echo "5. Commit and push your changes"
|
||||
echo ""
|
||||
exit 1
|
||||
else
|
||||
echo "✅ No API schema changes detected"
|
||||
fi
|
||||
|
||||
- name: Run Typescript checks
|
||||
run: pnpm types
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -5,6 +5,8 @@ classic/original_autogpt/*.json
|
||||
auto_gpt_workspace/*
|
||||
*.mpeg
|
||||
.env
|
||||
# Root .env files
|
||||
/.env
|
||||
azure.yaml
|
||||
.vscode
|
||||
.idea/*
|
||||
@@ -121,7 +123,6 @@ celerybeat.pid
|
||||
|
||||
# Environments
|
||||
.direnv/
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv*/
|
||||
|
||||
@@ -235,7 +235,7 @@ repos:
|
||||
hooks:
|
||||
- id: tsc
|
||||
name: Typecheck - AutoGPT Platform - Frontend
|
||||
entry: bash -c 'cd autogpt_platform/frontend && pnpm type-check'
|
||||
entry: bash -c 'cd autogpt_platform/frontend && pnpm types'
|
||||
files: ^autogpt_platform/frontend/
|
||||
types: [file]
|
||||
language: system
|
||||
|
||||
10
README.md
10
README.md
@@ -3,6 +3,16 @@
|
||||
[](https://discord.gg/autogpt)  
|
||||
[](https://twitter.com/Auto_GPT)  
|
||||
|
||||
<!-- Keep these links. Translations will automatically update with the README. -->
|
||||
[Deutsch](https://zdoc.app/de/Significant-Gravitas/AutoGPT) |
|
||||
[Español](https://zdoc.app/es/Significant-Gravitas/AutoGPT) |
|
||||
[français](https://zdoc.app/fr/Significant-Gravitas/AutoGPT) |
|
||||
[日本語](https://zdoc.app/ja/Significant-Gravitas/AutoGPT) |
|
||||
[한국어](https://zdoc.app/ko/Significant-Gravitas/AutoGPT) |
|
||||
[Português](https://zdoc.app/pt/Significant-Gravitas/AutoGPT) |
|
||||
[Русский](https://zdoc.app/ru/Significant-Gravitas/AutoGPT) |
|
||||
[中文](https://zdoc.app/zh/Significant-Gravitas/AutoGPT)
|
||||
|
||||
**AutoGPT** is a powerful platform that allows you to create, deploy, and manage continuous AI agents that automate complex workflows.
|
||||
|
||||
## Hosting Options
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Repository Overview
|
||||
|
||||
AutoGPT Platform is a monorepo containing:
|
||||
|
||||
- **Backend** (`/backend`): Python FastAPI server with async support
|
||||
- **Frontend** (`/frontend`): Next.js React application
|
||||
- **Shared Libraries** (`/autogpt_libs`): Common Python utilities
|
||||
@@ -11,6 +13,7 @@ AutoGPT Platform is a monorepo containing:
|
||||
## Essential Commands
|
||||
|
||||
### Backend Development
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
cd backend && poetry install
|
||||
@@ -30,11 +33,18 @@ poetry run test
|
||||
# Run specific test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Run block tests (tests that validate all blocks work correctly)
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||
|
||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
|
||||
More details can be found in TESTING.md
|
||||
|
||||
#### Creating/Updating Snapshots
|
||||
@@ -47,8 +57,8 @@ poetry run pytest path/to/test.py --snapshot-update
|
||||
|
||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||
|
||||
|
||||
### Frontend Development
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
cd frontend && npm install
|
||||
@@ -66,12 +76,13 @@ npm run storybook
|
||||
npm run build
|
||||
|
||||
# Type checking
|
||||
npm run type-check
|
||||
npm run types
|
||||
```
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Backend Architecture
|
||||
|
||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
@@ -80,6 +91,7 @@ npm run type-check
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
### Frontend Architecture
|
||||
|
||||
- **Framework**: Next.js App Router with React Server Components
|
||||
- **State Management**: React hooks + Supabase client for real-time updates
|
||||
- **Workflow Builder**: Visual graph editor using @xyflow/react
|
||||
@@ -87,6 +99,7 @@ npm run type-check
|
||||
- **Feature Flags**: LaunchDarkly integration
|
||||
|
||||
### Key Concepts
|
||||
|
||||
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
||||
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
|
||||
3. **Integrations**: OAuth and API connections stored per user
|
||||
@@ -94,13 +107,16 @@ npm run type-check
|
||||
5. **Virus Scanning**: ClamAV integration for file upload security
|
||||
|
||||
### Testing Approach
|
||||
|
||||
- Backend uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
- Frontend uses Playwright for E2E tests
|
||||
- Component testing via Storybook
|
||||
|
||||
### Database Schema
|
||||
|
||||
Key models (defined in `/backend/schema.prisma`):
|
||||
|
||||
- `User`: Authentication and profile data
|
||||
- `AgentGraph`: Workflow definitions with version control
|
||||
- `AgentGraphExecution`: Execution history and results
|
||||
@@ -108,13 +124,31 @@ Key models (defined in `/backend/schema.prisma`):
|
||||
- `StoreListing`: Marketplace listings for sharing agents
|
||||
|
||||
### Environment Configuration
|
||||
- Backend: `.env` file in `/backend`
|
||||
- Frontend: `.env.local` file in `/frontend`
|
||||
- Both require Supabase credentials and API keys for various services
|
||||
|
||||
#### Configuration Files
|
||||
|
||||
- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides)
|
||||
- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides)
|
||||
- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides)
|
||||
|
||||
#### Docker Environment Loading Order
|
||||
|
||||
1. `.env.default` files provide base configuration (tracked in git)
|
||||
2. `.env` files provide user-specific overrides (gitignored)
|
||||
3. Docker Compose `environment:` sections provide service-specific overrides
|
||||
4. Shell environment variables have highest precedence
|
||||
|
||||
#### Key Points
|
||||
|
||||
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
|
||||
- The `env_file` directive loads variables INTO containers at runtime
|
||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||
|
||||
### Common Development Tasks
|
||||
|
||||
**Adding a new block:**
|
||||
|
||||
1. Create new file in `/backend/backend/blocks/`
|
||||
2. Inherit from `Block` base class
|
||||
3. Define input/output schemas
|
||||
@@ -122,13 +156,18 @@ Key models (defined in `/backend/schema.prisma`):
|
||||
5. Register in block registry
|
||||
6. Generate the block uuid using `uuid.uuid4()`
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blcoks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
**Modifying the API:**
|
||||
|
||||
1. Update route in `/backend/backend/server/routers/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
**Frontend feature development:**
|
||||
|
||||
1. Components go in `/frontend/src/components/`
|
||||
2. Use existing UI components from `/frontend/src/components/ui/`
|
||||
3. Add Storybook stories for new components
|
||||
@@ -137,6 +176,7 @@ Key models (defined in `/backend/schema.prisma`):
|
||||
### Security Implementation
|
||||
|
||||
**Cache Protection Middleware:**
|
||||
|
||||
- Located in `/backend/backend/server/middleware/security.py`
|
||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||
@@ -144,3 +184,47 @@ Key models (defined in `/backend/schema.prisma`):
|
||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||
- Applied to both main API server and external API applications
|
||||
|
||||
### Creating Pull Requests
|
||||
|
||||
- Create the PR aginst the `dev` branch of the repository.
|
||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)/
|
||||
- Use conventional commit messages (see below)/
|
||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description/
|
||||
- Run the github pre-commit hooks to ensure code quality.
|
||||
|
||||
### Reviewing/Revising Pull Requests
|
||||
|
||||
- When the user runs /pr-comments or tries to fetch them, also run gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews to get the reviews
|
||||
- Use gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews/[review_id]/comments to get the review contents
|
||||
- Use gh api /repos/Significant-Gravitas/AutoGPT/issues/9924/comments to get the pr specific comments
|
||||
|
||||
### Conventional Commits
|
||||
|
||||
Use this format for commit messages and Pull Request titles:
|
||||
|
||||
**Conventional Commit Types:**
|
||||
|
||||
- `feat`: Introduces a new feature to the codebase
|
||||
- `fix`: Patches a bug in the codebase
|
||||
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
|
||||
- `ci`: Changes to CI configuration
|
||||
- `docs`: Documentation-only changes
|
||||
- `dx`: Improvements to the developer experience
|
||||
|
||||
**Recommended Base Scopes:**
|
||||
|
||||
- `platform`: Changes affecting both frontend and backend
|
||||
- `frontend`
|
||||
- `backend`
|
||||
- `infra`
|
||||
- `blocks`: Modifications/additions of individual blocks
|
||||
|
||||
**Subscope Examples:**
|
||||
|
||||
- `backend/executor`
|
||||
- `backend/db`
|
||||
- `frontend/builder` (includes changes to the block UI component)
|
||||
- `infra/prod`
|
||||
|
||||
Use these scopes and subscopes for clarity and consistency in commit messages.
|
||||
|
||||
@@ -8,7 +8,6 @@ Welcome to the AutoGPT Platform - a powerful system for creating and running AI
|
||||
|
||||
- Docker
|
||||
- Docker Compose V2 (comes with Docker Desktop, or can be installed separately)
|
||||
- Node.js & NPM (for running the frontend application)
|
||||
|
||||
### Running the System
|
||||
|
||||
@@ -24,10 +23,10 @@ To run the AutoGPT Platform, follow these steps:
|
||||
2. Run the following command:
|
||||
|
||||
```
|
||||
cp .env.example .env
|
||||
cp .env.default .env
|
||||
```
|
||||
|
||||
This command will copy the `.env.example` file to `.env`. You can modify the `.env` file to add your own environment variables.
|
||||
This command will copy the `.env.default` file to `.env`. You can modify the `.env` file to add your own environment variables.
|
||||
|
||||
3. Run the following command:
|
||||
|
||||
@@ -37,44 +36,7 @@ To run the AutoGPT Platform, follow these steps:
|
||||
|
||||
This command will start all the necessary backend services defined in the `docker-compose.yml` file in detached mode.
|
||||
|
||||
4. Navigate to `frontend` within the `autogpt_platform` directory:
|
||||
|
||||
```
|
||||
cd frontend
|
||||
```
|
||||
|
||||
You will need to run your frontend application separately on your local machine.
|
||||
|
||||
5. Run the following command:
|
||||
|
||||
```
|
||||
cp .env.example .env.local
|
||||
```
|
||||
|
||||
This command will copy the `.env.example` file to `.env.local` in the `frontend` directory. You can modify the `.env.local` within this folder to add your own environment variables for the frontend application.
|
||||
|
||||
6. Run the following command:
|
||||
|
||||
Enable corepack and install dependencies by running:
|
||||
|
||||
```
|
||||
corepack enable
|
||||
pnpm i
|
||||
```
|
||||
|
||||
Generate the API client (this step is required before running the frontend):
|
||||
|
||||
```
|
||||
pnpm generate:api-client
|
||||
```
|
||||
|
||||
Then start the frontend application in development mode:
|
||||
|
||||
```
|
||||
pnpm dev
|
||||
```
|
||||
|
||||
7. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
|
||||
4. After all the services are in ready state, open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
|
||||
|
||||
### Docker Compose Commands
|
||||
|
||||
@@ -177,20 +139,21 @@ The platform includes scripts for generating and managing the API client:
|
||||
|
||||
- `pnpm fetch:openapi`: Fetches the OpenAPI specification from the backend service (requires backend to be running on port 8006)
|
||||
- `pnpm generate:api-client`: Generates the TypeScript API client from the OpenAPI specification using Orval
|
||||
- `pnpm generate:api-all`: Runs both fetch and generate commands in sequence
|
||||
- `pnpm generate:api`: Runs both fetch and generate commands in sequence
|
||||
|
||||
#### Manual API Client Updates
|
||||
|
||||
If you need to update the API client after making changes to the backend API:
|
||||
|
||||
1. Ensure the backend services are running:
|
||||
|
||||
```
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
2. Generate the updated API client:
|
||||
```
|
||||
pnpm generate:api-all
|
||||
pnpm generate:api
|
||||
```
|
||||
|
||||
This will fetch the latest OpenAPI specification and regenerate the TypeScript client code.
|
||||
|
||||
@@ -7,9 +7,5 @@ class Settings:
|
||||
self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
|
||||
self.JWT_ALGORITHM: str = "HS256"
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
return bool(self.JWT_SECRET_KEY)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -1,166 +0,0 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, cast
|
||||
|
||||
import ldclient
|
||||
from fastapi import HTTPException
|
||||
from ldclient import Context, LDClient
|
||||
from ldclient.config import Config
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from .config import SETTINGS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_client() -> LDClient:
|
||||
"""Get the LaunchDarkly client singleton."""
|
||||
return ldclient.get()
|
||||
|
||||
|
||||
def initialize_launchdarkly() -> None:
|
||||
sdk_key = SETTINGS.launch_darkly_sdk_key
|
||||
logger.debug(
|
||||
f"Initializing LaunchDarkly with SDK key: {'present' if sdk_key else 'missing'}"
|
||||
)
|
||||
|
||||
if not sdk_key:
|
||||
logger.warning("LaunchDarkly SDK key not configured")
|
||||
return
|
||||
|
||||
config = Config(sdk_key)
|
||||
ldclient.set_config(config)
|
||||
|
||||
if ldclient.get().is_initialized():
|
||||
logger.info("LaunchDarkly client initialized successfully")
|
||||
else:
|
||||
logger.error("LaunchDarkly client failed to initialize")
|
||||
|
||||
|
||||
def shutdown_launchdarkly() -> None:
|
||||
"""Shutdown the LaunchDarkly client."""
|
||||
if ldclient.get().is_initialized():
|
||||
ldclient.get().close()
|
||||
logger.info("LaunchDarkly client closed successfully")
|
||||
|
||||
|
||||
def create_context(
|
||||
user_id: str, additional_attributes: Optional[Dict[str, Any]] = None
|
||||
) -> Context:
|
||||
"""Create LaunchDarkly context with optional additional attributes."""
|
||||
builder = Context.builder(str(user_id)).kind("user")
|
||||
if additional_attributes:
|
||||
for key, value in additional_attributes.items():
|
||||
builder.set(key, value)
|
||||
return builder.build()
|
||||
|
||||
|
||||
def feature_flag(
|
||||
flag_key: str,
|
||||
default: bool = False,
|
||||
) -> Callable[
|
||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
||||
]:
|
||||
"""
|
||||
Decorator for feature flag protected endpoints.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[P, Union[T, Awaitable[T]]],
|
||||
) -> Callable[P, Union[T, Awaitable[T]]]:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
try:
|
||||
user_id = kwargs.get("user_id")
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required")
|
||||
|
||||
if not get_client().is_initialized():
|
||||
logger.warning(
|
||||
f"LaunchDarkly not initialized, using default={default}"
|
||||
)
|
||||
is_enabled = default
|
||||
else:
|
||||
context = create_context(str(user_id))
|
||||
is_enabled = get_client().variation(flag_key, context, default)
|
||||
|
||||
if not is_enabled:
|
||||
raise HTTPException(status_code=404, detail="Feature not available")
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
if asyncio.iscoroutine(result):
|
||||
return await result
|
||||
return cast(T, result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
|
||||
raise
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
try:
|
||||
user_id = kwargs.get("user_id")
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required")
|
||||
|
||||
if not get_client().is_initialized():
|
||||
logger.warning(
|
||||
f"LaunchDarkly not initialized, using default={default}"
|
||||
)
|
||||
is_enabled = default
|
||||
else:
|
||||
context = create_context(str(user_id))
|
||||
is_enabled = get_client().variation(flag_key, context, default)
|
||||
|
||||
if not is_enabled:
|
||||
raise HTTPException(status_code=404, detail="Feature not available")
|
||||
|
||||
return cast(T, func(*args, **kwargs))
|
||||
except Exception as e:
|
||||
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
|
||||
raise
|
||||
|
||||
return cast(
|
||||
Callable[P, Union[T, Awaitable[T]]],
|
||||
async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper,
|
||||
)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def percentage_rollout(
|
||||
flag_key: str,
|
||||
default: bool = False,
|
||||
) -> Callable[
|
||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
||||
]:
|
||||
"""Decorator for percentage-based rollouts."""
|
||||
return feature_flag(flag_key, default)
|
||||
|
||||
|
||||
def beta_feature(
|
||||
flag_key: Optional[str] = None,
|
||||
unauthorized_response: Any = {"message": "Not available in beta"},
|
||||
) -> Callable[
|
||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
||||
]:
|
||||
"""Decorator for beta features."""
|
||||
actual_key = f"beta-{flag_key}" if flag_key else "beta"
|
||||
return feature_flag(actual_key, False)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def mock_flag_variation(flag_key: str, return_value: Any):
|
||||
"""Context manager for testing feature flags."""
|
||||
original_variation = get_client().variation
|
||||
get_client().variation = lambda key, context, default: (
|
||||
return_value if key == flag_key else original_variation(key, context, default)
|
||||
)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
get_client().variation = original_variation
|
||||
@@ -1,45 +0,0 @@
|
||||
import pytest
|
||||
from ldclient import LDClient
|
||||
|
||||
from autogpt_libs.feature_flag.client import feature_flag, mock_flag_variation
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ld_client(mocker):
|
||||
client = mocker.Mock(spec=LDClient)
|
||||
mocker.patch("ldclient.get", return_value=client)
|
||||
client.is_initialized.return_value = True
|
||||
return client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_feature_flag_enabled(ld_client):
|
||||
ld_client.variation.return_value = True
|
||||
|
||||
@feature_flag("test-flag")
|
||||
async def test_function(user_id: str):
|
||||
return "success"
|
||||
|
||||
result = test_function(user_id="test-user")
|
||||
assert result == "success"
|
||||
ld_client.variation.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_feature_flag_unauthorized_response(ld_client):
|
||||
ld_client.variation.return_value = False
|
||||
|
||||
@feature_flag("test-flag")
|
||||
async def test_function(user_id: str):
|
||||
return "success"
|
||||
|
||||
result = test_function(user_id="test-user")
|
||||
assert result == {"error": "disabled"}
|
||||
|
||||
|
||||
def test_mock_flag_variation(ld_client):
|
||||
with mock_flag_variation("test-flag", True):
|
||||
assert ld_client.variation("test-flag", None, False)
|
||||
|
||||
with mock_flag_variation("test-flag", False):
|
||||
assert ld_client.variation("test-flag", None, False)
|
||||
@@ -1,15 +0,0 @@
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
launch_darkly_sdk_key: str = Field(
|
||||
default="",
|
||||
description="The Launch Darkly SDK key",
|
||||
validation_alias="LAUNCH_DARKLY_SDK_KEY",
|
||||
)
|
||||
|
||||
model_config = SettingsConfigDict(case_sensitive=True, extra="ignore")
|
||||
|
||||
|
||||
SETTINGS = Settings()
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Logging module for Auto-GPT."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
@@ -10,6 +12,15 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from .filters import BelowLevelFilter
|
||||
from .formatters import AGPTFormatter
|
||||
|
||||
# Configure global socket timeout and gRPC keepalive to prevent deadlocks
|
||||
# This must be done at import time before any gRPC connections are established
|
||||
socket.setdefaulttimeout(30) # 30-second socket timeout
|
||||
|
||||
# Enable gRPC keepalive to detect dead connections faster
|
||||
os.environ.setdefault("GRPC_KEEPALIVE_TIME_MS", "30000") # 30 seconds
|
||||
os.environ.setdefault("GRPC_KEEPALIVE_TIMEOUT_MS", "5000") # 5 seconds
|
||||
os.environ.setdefault("GRPC_KEEPALIVE_PERMIT_WITHOUT_CALLS", "true")
|
||||
|
||||
LOG_DIR = Path(__file__).parent.parent.parent.parent / "logs"
|
||||
LOG_FILE = "activity.log"
|
||||
DEBUG_LOG_FILE = "debug.log"
|
||||
@@ -79,7 +90,6 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
|
||||
Note: This function is typically called at the start of the application
|
||||
to set up the logging infrastructure.
|
||||
"""
|
||||
|
||||
config = LoggingConfig()
|
||||
log_handlers: list[logging.Handler] = []
|
||||
|
||||
@@ -105,13 +115,17 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
|
||||
if config.enable_cloud_logging or force_cloud_logging:
|
||||
import google.cloud.logging
|
||||
from google.cloud.logging.handlers import CloudLoggingHandler
|
||||
from google.cloud.logging_v2.handlers.transports.sync import SyncTransport
|
||||
from google.cloud.logging_v2.handlers.transports import (
|
||||
BackgroundThreadTransport,
|
||||
)
|
||||
|
||||
client = google.cloud.logging.Client()
|
||||
# Use BackgroundThreadTransport to prevent blocking the main thread
|
||||
# and deadlocks when gRPC calls to Google Cloud Logging hang
|
||||
cloud_handler = CloudLoggingHandler(
|
||||
client,
|
||||
name="autogpt_logs",
|
||||
transport=SyncTransport,
|
||||
transport=BackgroundThreadTransport,
|
||||
)
|
||||
cloud_handler.setLevel(config.level)
|
||||
log_handlers.append(cloud_handler)
|
||||
|
||||
@@ -1,39 +1,5 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import uvicorn.config
|
||||
from colorama import Fore
|
||||
|
||||
|
||||
def remove_color_codes(s: str) -> str:
|
||||
return re.sub(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])", "", s)
|
||||
|
||||
|
||||
def fmt_kwargs(kwargs: dict) -> str:
|
||||
return ", ".join(f"{n}={repr(v)}" for n, v in kwargs.items())
|
||||
|
||||
|
||||
def print_attribute(
|
||||
title: str, value: Any, title_color: str = Fore.GREEN, value_color: str = ""
|
||||
) -> None:
|
||||
logger = logging.getLogger()
|
||||
logger.info(
|
||||
str(value),
|
||||
extra={
|
||||
"title": f"{title.rstrip(':')}:",
|
||||
"title_color": title_color,
|
||||
"color": value_color,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def generate_uvicorn_config():
|
||||
"""
|
||||
Generates a uvicorn logging config that silences uvicorn's default logging and tells it to use the native logging module.
|
||||
"""
|
||||
log_config = dict(uvicorn.config.LOGGING_CONFIG)
|
||||
log_config["loggers"]["uvicorn"] = {"handlers": []}
|
||||
log_config["loggers"]["uvicorn.error"] = {"handlers": []}
|
||||
log_config["loggers"]["uvicorn.access"] = {"handlers": []}
|
||||
return log_config
|
||||
|
||||
@@ -1,17 +1,34 @@
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, R]) -> Callable[P, R]: ...
|
||||
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
|
||||
pass
|
||||
|
||||
|
||||
def thread_cached(
|
||||
@@ -57,3 +74,193 @@ def thread_cached(
|
||||
def clear_thread_cache(func: Callable) -> None:
|
||||
if clear := getattr(func, "clear_cache", None):
|
||||
clear()
|
||||
|
||||
|
||||
FuncT = TypeVar("FuncT")
|
||||
|
||||
|
||||
R_co = TypeVar("R_co", covariant=True)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AsyncCachedFunction(Protocol[P, R_co]):
|
||||
"""Protocol for async functions with cache management methods."""
|
||||
|
||||
def cache_clear(self) -> None:
|
||||
"""Clear all cached entries."""
|
||||
return None
|
||||
|
||||
def cache_info(self) -> dict[str, int | None]:
|
||||
"""Get cache statistics."""
|
||||
return {}
|
||||
|
||||
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
|
||||
"""Call the cached function."""
|
||||
return None # type: ignore
|
||||
|
||||
|
||||
def async_ttl_cache(
|
||||
maxsize: int = 128, ttl_seconds: int | None = None
|
||||
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
|
||||
"""
|
||||
TTL (Time To Live) cache decorator for async functions.
|
||||
|
||||
Similar to functools.lru_cache but works with async functions and includes optional TTL.
|
||||
|
||||
Args:
|
||||
maxsize: Maximum number of cached entries
|
||||
ttl_seconds: Time to live in seconds. If None, entries never expire (like lru_cache)
|
||||
|
||||
Returns:
|
||||
Decorator function
|
||||
|
||||
Example:
|
||||
# With TTL
|
||||
@async_ttl_cache(maxsize=1000, ttl_seconds=300)
|
||||
async def api_call(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
# Without TTL (permanent cache like lru_cache)
|
||||
@async_ttl_cache(maxsize=1000)
|
||||
async def expensive_computation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
async_func: Callable[P, Awaitable[R]],
|
||||
) -> AsyncCachedFunction[P, R]:
|
||||
# Cache storage - use union type to handle both cases
|
||||
cache_storage: dict[tuple, R | Tuple[R, float]] = {}
|
||||
|
||||
@wraps(async_func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# Create cache key from arguments
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
current_time = time.time()
|
||||
|
||||
# Check if we have a valid cached entry
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
# No TTL - return cached result directly
|
||||
logger.debug(
|
||||
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
return cast(R, cache_storage[key])
|
||||
else:
|
||||
# With TTL - check expiration
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
logger.debug(
|
||||
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
return cast(R, result)
|
||||
else:
|
||||
# Expired entry
|
||||
del cache_storage[key]
|
||||
logger.debug(
|
||||
f"Cache entry expired for {async_func.__name__}"
|
||||
)
|
||||
|
||||
# Cache miss or expired - fetch fresh data
|
||||
logger.debug(
|
||||
f"Cache miss for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
result = await async_func(*args, **kwargs)
|
||||
|
||||
# Store in cache
|
||||
if ttl_seconds is None:
|
||||
cache_storage[key] = result
|
||||
else:
|
||||
cache_storage[key] = (result, current_time)
|
||||
|
||||
# Simple cleanup when cache gets too large
|
||||
if len(cache_storage) > maxsize:
|
||||
# Remove oldest entries (simple FIFO cleanup)
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
logger.debug(
|
||||
f"Cache cleanup: removed {len(oldest_keys)} entries for {async_func.__name__}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
# Add cache management methods (similar to functools.lru_cache)
|
||||
def cache_clear() -> None:
|
||||
cache_storage.clear()
|
||||
|
||||
def cache_info() -> dict[str, int | None]:
|
||||
return {
|
||||
"size": len(cache_storage),
|
||||
"maxsize": maxsize,
|
||||
"ttl_seconds": ttl_seconds,
|
||||
}
|
||||
|
||||
# Attach methods to wrapper
|
||||
setattr(wrapper, "cache_clear", cache_clear)
|
||||
setattr(wrapper, "cache_info", cache_info)
|
||||
|
||||
return cast(AsyncCachedFunction[P, R], wrapper)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@overload
|
||||
def async_cache(
|
||||
func: Callable[P, Awaitable[R]],
|
||||
) -> AsyncCachedFunction[P, R]:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def async_cache(
|
||||
func: None = None,
|
||||
*,
|
||||
maxsize: int = 128,
|
||||
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
|
||||
pass
|
||||
|
||||
|
||||
def async_cache(
|
||||
func: Callable[P, Awaitable[R]] | None = None,
|
||||
*,
|
||||
maxsize: int = 128,
|
||||
) -> (
|
||||
AsyncCachedFunction[P, R]
|
||||
| Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]
|
||||
):
|
||||
"""
|
||||
Process-level cache decorator for async functions (no TTL).
|
||||
|
||||
Similar to functools.lru_cache but works with async functions.
|
||||
This is a convenience wrapper around async_ttl_cache with ttl_seconds=None.
|
||||
|
||||
Args:
|
||||
func: The async function to cache (when used without parentheses)
|
||||
maxsize: Maximum number of cached entries
|
||||
|
||||
Returns:
|
||||
Decorated function or decorator
|
||||
|
||||
Example:
|
||||
# Without parentheses (uses default maxsize=128)
|
||||
@async_cache
|
||||
async def get_data(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
# With parentheses and custom maxsize
|
||||
@async_cache(maxsize=1000)
|
||||
async def expensive_computation(param: str) -> dict:
|
||||
# Expensive computation here
|
||||
return {"result": param}
|
||||
"""
|
||||
if func is None:
|
||||
# Called with parentheses @async_cache() or @async_cache(maxsize=...)
|
||||
return async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
|
||||
else:
|
||||
# Called without parentheses @async_cache
|
||||
decorator = async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
|
||||
return decorator(func)
|
||||
|
||||
@@ -16,7 +16,12 @@ from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from autogpt_libs.utils.cache import clear_thread_cache, thread_cached
|
||||
from autogpt_libs.utils.cache import (
|
||||
async_cache,
|
||||
async_ttl_cache,
|
||||
clear_thread_cache,
|
||||
thread_cached,
|
||||
)
|
||||
|
||||
|
||||
class TestThreadCached:
|
||||
@@ -323,3 +328,378 @@ class TestThreadCached:
|
||||
|
||||
assert function_using_mock(2) == 42
|
||||
assert mock.call_count == 2
|
||||
|
||||
|
||||
class TestAsyncTTLCache:
|
||||
"""Tests for the @async_ttl_cache decorator."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_caching(self):
|
||||
"""Test basic caching functionality."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def cached_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01) # Simulate async work
|
||||
return x + y
|
||||
|
||||
# First call
|
||||
result1 = await cached_function(1, 2)
|
||||
assert result1 == 3
|
||||
assert call_count == 1
|
||||
|
||||
# Second call with same args - should use cache
|
||||
result2 = await cached_function(1, 2)
|
||||
assert result2 == 3
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Different args - should call function again
|
||||
result3 = await cached_function(2, 3)
|
||||
assert result3 == 5
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ttl_expiration(self):
|
||||
"""Test that cache entries expire after TTL."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
|
||||
async def short_lived_cache(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 2
|
||||
|
||||
# First call
|
||||
result1 = await short_lived_cache(5)
|
||||
assert result1 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Second call immediately - should use cache
|
||||
result2 = await short_lived_cache(5)
|
||||
assert result2 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# Third call after expiration - should call function again
|
||||
result3 = await short_lived_cache(5)
|
||||
assert result3 == 10
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_info(self):
|
||||
"""Test cache info functionality."""
|
||||
|
||||
@async_ttl_cache(maxsize=5, ttl_seconds=300)
|
||||
async def info_test_function(x: int) -> int:
|
||||
return x * 3
|
||||
|
||||
# Check initial cache info
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 0
|
||||
assert info["maxsize"] == 5
|
||||
assert info["ttl_seconds"] == 300
|
||||
|
||||
# Add an entry
|
||||
await info_test_function(1)
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_clear(self):
|
||||
"""Test cache clearing functionality."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def clearable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 4
|
||||
|
||||
# First call
|
||||
result1 = await clearable_function(2)
|
||||
assert result1 == 8
|
||||
assert call_count == 1
|
||||
|
||||
# Second call - should use cache
|
||||
result2 = await clearable_function(2)
|
||||
assert result2 == 8
|
||||
assert call_count == 1
|
||||
|
||||
# Clear cache
|
||||
clearable_function.cache_clear()
|
||||
|
||||
# Third call after clear - should call function again
|
||||
result3 = await clearable_function(2)
|
||||
assert result3 == 8
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maxsize_cleanup(self):
|
||||
"""Test that cache cleans up when maxsize is exceeded."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=3, ttl_seconds=60)
|
||||
async def size_limited_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x**2
|
||||
|
||||
# Fill cache to maxsize
|
||||
await size_limited_function(1) # call_count: 1
|
||||
await size_limited_function(2) # call_count: 2
|
||||
await size_limited_function(3) # call_count: 3
|
||||
|
||||
info = size_limited_function.cache_info()
|
||||
assert info["size"] == 3
|
||||
|
||||
# Add one more entry - should trigger cleanup
|
||||
await size_limited_function(4) # call_count: 4
|
||||
|
||||
# Cache size should be reduced (cleanup removes oldest entries)
|
||||
info = size_limited_function.cache_info()
|
||||
assert info["size"] is not None and info["size"] <= 3 # Should be cleaned up
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_argument_variations(self):
|
||||
"""Test caching with different argument patterns."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def arg_test_function(a: int, b: str = "default", *, c: int = 100) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"{a}-{b}-{c}"
|
||||
|
||||
# Different ways to call with same logical arguments
|
||||
result1 = await arg_test_function(1, "test", c=200)
|
||||
assert call_count == 1
|
||||
|
||||
# Same arguments, same order - should use cache
|
||||
result2 = await arg_test_function(1, "test", c=200)
|
||||
assert call_count == 1
|
||||
assert result1 == result2
|
||||
|
||||
# Different arguments - should call function
|
||||
result3 = await arg_test_function(2, "test", c=200)
|
||||
assert call_count == 2
|
||||
assert result1 != result3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_handling(self):
|
||||
"""Test that exceptions are not cached."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def exception_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if x < 0:
|
||||
raise ValueError("Negative value not allowed")
|
||||
return x * 2
|
||||
|
||||
# Successful call - should be cached
|
||||
result1 = await exception_function(5)
|
||||
assert result1 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Same successful call - should use cache
|
||||
result2 = await exception_function(5)
|
||||
assert result2 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Exception call - should not be cached
|
||||
with pytest.raises(ValueError):
|
||||
await exception_function(-1)
|
||||
assert call_count == 2
|
||||
|
||||
# Same exception call - should call again (not cached)
|
||||
with pytest.raises(ValueError):
|
||||
await exception_function(-1)
|
||||
assert call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_calls(self):
|
||||
"""Test caching behavior with concurrent calls."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def concurrent_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.05) # Simulate work
|
||||
return x * x
|
||||
|
||||
# Launch concurrent calls with same arguments
|
||||
tasks = [concurrent_function(3) for _ in range(5)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# All results should be the same
|
||||
assert all(result == 9 for result in results)
|
||||
|
||||
# Note: Due to race conditions, call_count might be up to 5 for concurrent calls
|
||||
# This tests that the cache doesn't break under concurrent access
|
||||
assert 1 <= call_count <= 5
|
||||
|
||||
|
||||
class TestAsyncCache:
|
||||
"""Tests for the @async_cache decorator (no TTL)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_caching_no_ttl(self):
|
||||
"""Test basic caching functionality without TTL."""
|
||||
call_count = 0
|
||||
|
||||
@async_cache(maxsize=10)
|
||||
async def cached_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01) # Simulate async work
|
||||
return x + y
|
||||
|
||||
# First call
|
||||
result1 = await cached_function(1, 2)
|
||||
assert result1 == 3
|
||||
assert call_count == 1
|
||||
|
||||
# Second call with same args - should use cache
|
||||
result2 = await cached_function(1, 2)
|
||||
assert result2 == 3
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Third call after some time - should still use cache (no TTL)
|
||||
await asyncio.sleep(0.05)
|
||||
result3 = await cached_function(1, 2)
|
||||
assert result3 == 3
|
||||
assert call_count == 1 # Still no additional call
|
||||
|
||||
# Different args - should call function again
|
||||
result4 = await cached_function(2, 3)
|
||||
assert result4 == 5
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_ttl_vs_ttl_behavior(self):
|
||||
"""Test the difference between TTL and no-TTL caching."""
|
||||
ttl_call_count = 0
|
||||
no_ttl_call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
|
||||
async def ttl_function(x: int) -> int:
|
||||
nonlocal ttl_call_count
|
||||
ttl_call_count += 1
|
||||
return x * 2
|
||||
|
||||
@async_cache(maxsize=10) # No TTL
|
||||
async def no_ttl_function(x: int) -> int:
|
||||
nonlocal no_ttl_call_count
|
||||
no_ttl_call_count += 1
|
||||
return x * 2
|
||||
|
||||
# First calls
|
||||
await ttl_function(5)
|
||||
await no_ttl_function(5)
|
||||
assert ttl_call_count == 1
|
||||
assert no_ttl_call_count == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# Second calls after TTL expiry
|
||||
await ttl_function(5) # Should call function again (TTL expired)
|
||||
await no_ttl_function(5) # Should use cache (no TTL)
|
||||
assert ttl_call_count == 2 # TTL function called again
|
||||
assert no_ttl_call_count == 1 # No-TTL function still cached
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_cache_info(self):
|
||||
"""Test cache info for no-TTL cache."""
|
||||
|
||||
@async_cache(maxsize=5)
|
||||
async def info_test_function(x: int) -> int:
|
||||
return x * 3
|
||||
|
||||
# Check initial cache info
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 0
|
||||
assert info["maxsize"] == 5
|
||||
assert info["ttl_seconds"] is None # No TTL
|
||||
|
||||
# Add an entry
|
||||
await info_test_function(1)
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 1
|
||||
|
||||
|
||||
class TestTTLOptional:
|
||||
"""Tests for optional TTL functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ttl_none_behavior(self):
|
||||
"""Test that ttl_seconds=None works like no TTL."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=None)
|
||||
async def no_ttl_via_none(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x**2
|
||||
|
||||
# First call
|
||||
result1 = await no_ttl_via_none(3)
|
||||
assert result1 == 9
|
||||
assert call_count == 1
|
||||
|
||||
# Wait (would expire if there was TTL)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Second call - should still use cache
|
||||
result2 = await no_ttl_via_none(3)
|
||||
assert result2 == 9
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Check cache info
|
||||
info = no_ttl_via_none.cache_info()
|
||||
assert info["ttl_seconds"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_options_comparison(self):
|
||||
"""Test different cache options work as expected."""
|
||||
ttl_calls = 0
|
||||
no_ttl_calls = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # With TTL
|
||||
async def ttl_function(x: int) -> int:
|
||||
nonlocal ttl_calls
|
||||
ttl_calls += 1
|
||||
return x * 10
|
||||
|
||||
@async_cache(maxsize=10) # Process-level cache (no TTL)
|
||||
async def process_function(x: int) -> int:
|
||||
nonlocal no_ttl_calls
|
||||
no_ttl_calls += 1
|
||||
return x * 10
|
||||
|
||||
# Both should cache initially
|
||||
await ttl_function(3)
|
||||
await process_function(3)
|
||||
assert ttl_calls == 1
|
||||
assert no_ttl_calls == 1
|
||||
|
||||
# Immediate second calls - both should use cache
|
||||
await ttl_function(3)
|
||||
await process_function(3)
|
||||
assert ttl_calls == 1
|
||||
assert no_ttl_calls == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# After TTL expiry
|
||||
await ttl_function(3) # Should call function again
|
||||
await process_function(3) # Should still use cache
|
||||
assert ttl_calls == 2 # TTL cache expired, called again
|
||||
assert no_ttl_calls == 1 # Process cache never expires
|
||||
|
||||
52
autogpt_platform/backend/.dockerignore
Normal file
52
autogpt_platform/backend/.dockerignore
Normal file
@@ -0,0 +1,52 @@
|
||||
# Development and testing files
|
||||
**/__pycache__
|
||||
**/*.pyc
|
||||
**/*.pyo
|
||||
**/*.pyd
|
||||
**/.Python
|
||||
**/env/
|
||||
**/venv/
|
||||
**/.venv/
|
||||
**/pip-log.txt
|
||||
**/.pytest_cache/
|
||||
**/test-results/
|
||||
**/snapshots/
|
||||
**/test/
|
||||
|
||||
# IDE and editor files
|
||||
**/.vscode/
|
||||
**/.idea/
|
||||
**/*.swp
|
||||
**/*.swo
|
||||
*~
|
||||
|
||||
# OS files
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Logs
|
||||
**/*.log
|
||||
**/logs/
|
||||
|
||||
# Git
|
||||
.git/
|
||||
.gitignore
|
||||
|
||||
# Documentation
|
||||
**/*.md
|
||||
!README.md
|
||||
|
||||
# Local development files
|
||||
.env
|
||||
.env.local
|
||||
**/.env.test
|
||||
|
||||
# Build artifacts
|
||||
**/dist/
|
||||
**/build/
|
||||
**/target/
|
||||
|
||||
# Docker files (avoid recursion)
|
||||
Dockerfile*
|
||||
docker-compose*
|
||||
.dockerignore
|
||||
@@ -1,3 +1,9 @@
|
||||
# Backend Configuration
|
||||
# This file contains environment variables that MUST be set for the AutoGPT platform
|
||||
# Variables with working defaults in settings.py are not included here
|
||||
|
||||
## ===== REQUIRED DATABASE CONFIGURATION ===== ##
|
||||
# PostgreSQL Database Connection
|
||||
DB_USER=postgres
|
||||
DB_PASS=your-super-secret-and-long-postgres-password
|
||||
DB_NAME=postgres
|
||||
@@ -10,72 +16,50 @@ DB_SCHEMA=platform
|
||||
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
|
||||
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
|
||||
PRISMA_SCHEMA="postgres/schema.prisma"
|
||||
ENABLE_AUTH=true
|
||||
|
||||
# EXECUTOR
|
||||
NUM_GRAPH_WORKERS=10
|
||||
|
||||
BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]
|
||||
|
||||
# generate using `from cryptography.fernet import Fernet;Fernet.generate_key().decode()`
|
||||
ENCRYPTION_KEY='dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw='
|
||||
UNSUBSCRIBE_SECRET_KEY = 'HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio='
|
||||
|
||||
## ===== REQUIRED SERVICE CREDENTIALS ===== ##
|
||||
# Redis Configuration
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_PASSWORD=password
|
||||
|
||||
ENABLE_CREDIT=false
|
||||
STRIPE_API_KEY=
|
||||
STRIPE_WEBHOOK_SECRET=
|
||||
# RabbitMQ Credentials
|
||||
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
|
||||
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
|
||||
|
||||
# What environment things should be logged under: local dev or prod
|
||||
APP_ENV=local
|
||||
# What environment to behave as: "local" or "cloud"
|
||||
BEHAVE_AS=local
|
||||
PYRO_HOST=localhost
|
||||
SENTRY_DSN=
|
||||
|
||||
# Email For Postmark so we can send emails
|
||||
POSTMARK_SERVER_API_TOKEN=
|
||||
POSTMARK_SENDER_EMAIL=invalid@invalid.com
|
||||
POSTMARK_WEBHOOK_TOKEN=
|
||||
|
||||
## User auth with Supabase is required for any of the 3rd party integrations with auth to work.
|
||||
ENABLE_AUTH=true
|
||||
# Supabase Authentication
|
||||
SUPABASE_URL=http://localhost:8000
|
||||
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
|
||||
# RabbitMQ credentials -- Used for communication between services
|
||||
RABBITMQ_HOST=localhost
|
||||
RABBITMQ_PORT=5672
|
||||
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
|
||||
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
|
||||
## ===== REQUIRED SECURITY KEYS ===== ##
|
||||
# Generate using: from cryptography.fernet import Fernet;Fernet.generate_key().decode()
|
||||
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
|
||||
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
|
||||
|
||||
## GCS bucket is required for marketplace and library functionality
|
||||
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
|
||||
# Platform URLs (set these for webhooks and OAuth to work)
|
||||
PLATFORM_BASE_URL=http://localhost:8000
|
||||
FRONTEND_BASE_URL=http://localhost:3000
|
||||
|
||||
# Media Storage (required for marketplace and library functionality)
|
||||
MEDIA_GCS_BUCKET_NAME=
|
||||
|
||||
## For local development, you may need to set FRONTEND_BASE_URL for the OAuth flow
|
||||
## for integrations to work. Defaults to the value of PLATFORM_BASE_URL if not set.
|
||||
# FRONTEND_BASE_URL=http://localhost:3000
|
||||
## ===== API KEYS AND OAUTH CREDENTIALS ===== ##
|
||||
# All API keys below are optional - only add what you need
|
||||
|
||||
## PLATFORM_BASE_URL must be set to a *publicly accessible* URL pointing to your backend
|
||||
## to use the platform's webhook-related functionality.
|
||||
## If you are developing locally, you can use something like ngrok to get a publc URL
|
||||
## and tunnel it to your locally running backend.
|
||||
PLATFORM_BASE_URL=http://localhost:3000
|
||||
|
||||
## Cloudflare Turnstile (CAPTCHA) Configuration
|
||||
## Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
|
||||
## This is the backend secret key
|
||||
TURNSTILE_SECRET_KEY=
|
||||
## This is the verify URL
|
||||
TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify
|
||||
|
||||
## == INTEGRATION CREDENTIALS == ##
|
||||
# Each set of server side credentials is required for the corresponding 3rd party
|
||||
# integration to work.
|
||||
# AI/LLM Services
|
||||
OPENAI_API_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
GROQ_API_KEY=
|
||||
LLAMA_API_KEY=
|
||||
AIML_API_KEY=
|
||||
V0_API_KEY=
|
||||
OPEN_ROUTER_API_KEY=
|
||||
NVIDIA_API_KEY=
|
||||
|
||||
# OAuth Credentials
|
||||
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
|
||||
# e.g. http://localhost:3000/auth/integrations/oauth_callback
|
||||
|
||||
@@ -85,7 +69,6 @@ GITHUB_CLIENT_SECRET=
|
||||
|
||||
# Google OAuth App server credentials - https://console.cloud.google.com/apis/credentials, and enable gmail api and set scopes
|
||||
# https://console.cloud.google.com/apis/credentials/consent ?project=<your_project_id>
|
||||
|
||||
# You'll need to add/enable the following scopes (minimum):
|
||||
# https://console.developers.google.com/apis/api/gmail.googleapis.com/overview ?project=<your_project_id>
|
||||
# https://console.cloud.google.com/apis/library/sheets.googleapis.com/ ?project=<your_project_id>
|
||||
@@ -121,104 +104,66 @@ LINEAR_CLIENT_SECRET=
|
||||
TODOIST_CLIENT_ID=
|
||||
TODOIST_CLIENT_SECRET=
|
||||
|
||||
## ===== OPTIONAL API KEYS ===== ##
|
||||
|
||||
# LLM
|
||||
OPENAI_API_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
AIML_API_KEY=
|
||||
GROQ_API_KEY=
|
||||
OPEN_ROUTER_API_KEY=
|
||||
LLAMA_API_KEY=
|
||||
|
||||
# Reddit
|
||||
# Go to https://www.reddit.com/prefs/apps and create a new app
|
||||
# Choose "script" for the type
|
||||
# Fill in the redirect uri as <your_frontend_url>/auth/integrations/oauth_callback, e.g. http://localhost:3000/auth/integrations/oauth_callback
|
||||
NOTION_CLIENT_ID=
|
||||
NOTION_CLIENT_SECRET=
|
||||
REDDIT_CLIENT_ID=
|
||||
REDDIT_CLIENT_SECRET=
|
||||
REDDIT_USER_AGENT="AutoGPT:1.0 (by /u/autogpt)"
|
||||
|
||||
# Discord
|
||||
DISCORD_BOT_TOKEN=
|
||||
# Payment Processing
|
||||
STRIPE_API_KEY=
|
||||
STRIPE_WEBHOOK_SECRET=
|
||||
|
||||
# SMTP/Email
|
||||
SMTP_SERVER=
|
||||
SMTP_PORT=
|
||||
SMTP_USERNAME=
|
||||
SMTP_PASSWORD=
|
||||
# Email Service (for sending notifications and confirmations)
|
||||
POSTMARK_SERVER_API_TOKEN=
|
||||
POSTMARK_SENDER_EMAIL=invalid@invalid.com
|
||||
POSTMARK_WEBHOOK_TOKEN=
|
||||
|
||||
# D-ID
|
||||
# Error Tracking
|
||||
SENTRY_DSN=
|
||||
|
||||
# Cloudflare Turnstile (CAPTCHA) Configuration
|
||||
# Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
|
||||
# This is the backend secret key
|
||||
TURNSTILE_SECRET_KEY=
|
||||
# This is the verify URL
|
||||
TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify
|
||||
|
||||
# Feature Flags
|
||||
LAUNCH_DARKLY_SDK_KEY=
|
||||
|
||||
# Content Generation & Media
|
||||
DID_API_KEY=
|
||||
FAL_API_KEY=
|
||||
IDEOGRAM_API_KEY=
|
||||
REPLICATE_API_KEY=
|
||||
REVID_API_KEY=
|
||||
SCREENSHOTONE_API_KEY=
|
||||
UNREAL_SPEECH_API_KEY=
|
||||
|
||||
# Open Weather Map
|
||||
# Data & Search Services
|
||||
E2B_API_KEY=
|
||||
EXA_API_KEY=
|
||||
JINA_API_KEY=
|
||||
MEM0_API_KEY=
|
||||
OPENWEATHERMAP_API_KEY=
|
||||
|
||||
# SMTP
|
||||
SMTP_SERVER=
|
||||
SMTP_PORT=
|
||||
SMTP_USERNAME=
|
||||
SMTP_PASSWORD=
|
||||
|
||||
# Medium
|
||||
MEDIUM_API_KEY=
|
||||
MEDIUM_AUTHOR_ID=
|
||||
|
||||
# Google Maps
|
||||
GOOGLE_MAPS_API_KEY=
|
||||
|
||||
# Replicate
|
||||
REPLICATE_API_KEY=
|
||||
# Communication Services
|
||||
DISCORD_BOT_TOKEN=
|
||||
MEDIUM_API_KEY=
|
||||
MEDIUM_AUTHOR_ID=
|
||||
SMTP_SERVER=
|
||||
SMTP_PORT=
|
||||
SMTP_USERNAME=
|
||||
SMTP_PASSWORD=
|
||||
|
||||
# Ideogram
|
||||
IDEOGRAM_API_KEY=
|
||||
|
||||
# Fal
|
||||
FAL_API_KEY=
|
||||
|
||||
# Exa
|
||||
EXA_API_KEY=
|
||||
|
||||
# E2B
|
||||
E2B_API_KEY=
|
||||
|
||||
# Mem0
|
||||
MEM0_API_KEY=
|
||||
|
||||
# Nvidia
|
||||
NVIDIA_API_KEY=
|
||||
|
||||
# Apollo
|
||||
# Business & Marketing Tools
|
||||
APOLLO_API_KEY=
|
||||
|
||||
# SmartLead
|
||||
SMARTLEAD_API_KEY=
|
||||
|
||||
# ZeroBounce
|
||||
ZEROBOUNCE_API_KEY=
|
||||
|
||||
# Ayrshare
|
||||
ENRICHLAYER_API_KEY=
|
||||
AYRSHARE_API_KEY=
|
||||
AYRSHARE_JWT_KEY=
|
||||
SMARTLEAD_API_KEY=
|
||||
ZEROBOUNCE_API_KEY=
|
||||
|
||||
## ===== OPTIONAL API KEYS END ===== ##
|
||||
|
||||
# Block Error Rate Monitoring
|
||||
BLOCK_ERROR_RATE_THRESHOLD=0.5
|
||||
BLOCK_ERROR_RATE_CHECK_INTERVAL_SECS=86400
|
||||
|
||||
# Logging Configuration
|
||||
LOG_LEVEL=INFO
|
||||
ENABLE_CLOUD_LOGGING=false
|
||||
ENABLE_FILE_LOGGING=false
|
||||
# Use to manually set the log directory
|
||||
# LOG_DIR=./logs
|
||||
|
||||
# Example Blocks Configuration
|
||||
# Set to true to enable example blocks in development
|
||||
# These blocks are disabled by default in production
|
||||
ENABLE_EXAMPLE_BLOCKS=false
|
||||
|
||||
# Cloud Storage Configuration
|
||||
# Cleanup interval for expired files (hours between cleanup runs, 1-24 hours)
|
||||
CLOUD_STORAGE_CLEANUP_INTERVAL_HOURS=6
|
||||
# Other Services
|
||||
AUTOMOD_API_KEY=
|
||||
1
autogpt_platform/backend/.gitignore
vendored
1
autogpt_platform/backend/.gitignore
vendored
@@ -1,3 +1,4 @@
|
||||
.env
|
||||
database.db
|
||||
database.db-journal
|
||||
dev.db
|
||||
|
||||
@@ -8,14 +8,14 @@ WORKDIR /app
|
||||
|
||||
RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy
|
||||
|
||||
RUN apt-get update --allow-releaseinfo-change --fix-missing
|
||||
|
||||
# Install build dependencies
|
||||
RUN apt-get install -y build-essential
|
||||
RUN apt-get install -y libpq5
|
||||
RUN apt-get install -y libz-dev
|
||||
RUN apt-get install -y libssl-dev
|
||||
RUN apt-get install -y postgresql-client
|
||||
# Update package list and install build dependencies in a single layer
|
||||
RUN apt-get update --allow-releaseinfo-change --fix-missing \
|
||||
&& apt-get install -y \
|
||||
build-essential \
|
||||
libpq5 \
|
||||
libz-dev \
|
||||
libssl-dev \
|
||||
postgresql-client
|
||||
|
||||
ENV POETRY_HOME=/opt/poetry
|
||||
ENV POETRY_NO_INTERACTION=1
|
||||
@@ -68,6 +68,12 @@ COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.tom
|
||||
|
||||
WORKDIR /app/autogpt_platform/backend
|
||||
|
||||
FROM server_dependencies AS migrate
|
||||
|
||||
# Migration stage only needs schema and migrations - much lighter than full backend
|
||||
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
|
||||
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations
|
||||
|
||||
FROM server_dependencies AS server
|
||||
|
||||
COPY autogpt_platform/backend /app/autogpt_platform/backend
|
||||
|
||||
@@ -43,11 +43,11 @@ def main(**kwargs):
|
||||
|
||||
run_processes(
|
||||
DatabaseManager().set_log_level("warning"),
|
||||
ExecutionManager(),
|
||||
Scheduler(),
|
||||
NotificationManager(),
|
||||
WebsocketServer(),
|
||||
AgentServer(),
|
||||
ExecutionManager(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -15,7 +14,8 @@ from backend.data.block import (
|
||||
)
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util import json, retry
|
||||
from backend.util.json import validate_with_jsonschema
|
||||
from backend.util.retry import func_retry
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -49,7 +49,7 @@ class AgentExecutorBlock(Block):
|
||||
|
||||
@classmethod
|
||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||
return json.validate_with_jsonschema(cls.get_input_schema(data), data)
|
||||
return validate_with_jsonschema(cls.get_input_schema(data), data)
|
||||
|
||||
class Output(BlockSchema):
|
||||
pass
|
||||
@@ -95,23 +95,14 @@ class AgentExecutorBlock(Block):
|
||||
logger=logger,
|
||||
):
|
||||
yield name, data
|
||||
except asyncio.CancelledError:
|
||||
except BaseException as e:
|
||||
await self._stop(
|
||||
graph_exec_id=graph_exec.id,
|
||||
user_id=input_data.user_id,
|
||||
logger=logger,
|
||||
)
|
||||
logger.warning(
|
||||
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} was cancelled."
|
||||
)
|
||||
except Exception as e:
|
||||
await self._stop(
|
||||
graph_exec_id=graph_exec.id,
|
||||
user_id=input_data.user_id,
|
||||
logger=logger,
|
||||
)
|
||||
logger.error(
|
||||
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} failed: {e}, execution is stopped."
|
||||
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} failed: {e.__class__.__name__} {str(e)}; execution is stopped."
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -131,6 +122,7 @@ class AgentExecutorBlock(Block):
|
||||
|
||||
log_id = f"Graph #{graph_id}-V{graph_version}, exec-id: {graph_exec_id}"
|
||||
logger.info(f"Starting execution of {log_id}")
|
||||
yielded_node_exec_ids = set()
|
||||
|
||||
async for event in event_bus.listen(
|
||||
user_id=user_id,
|
||||
@@ -162,6 +154,14 @@ class AgentExecutorBlock(Block):
|
||||
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
|
||||
)
|
||||
|
||||
if event.node_exec_id in yielded_node_exec_ids:
|
||||
logger.warning(
|
||||
f"{log_id} received duplicate event for node execution {event.node_exec_id}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
yielded_node_exec_ids.add(event.node_exec_id)
|
||||
|
||||
if not event.block_id:
|
||||
logger.warning(f"{log_id} received event without block_id {event}")
|
||||
continue
|
||||
@@ -181,7 +181,7 @@ class AgentExecutorBlock(Block):
|
||||
)
|
||||
yield output_name, output_data
|
||||
|
||||
@retry.func_retry
|
||||
@func_retry
|
||||
async def _stop(
|
||||
self,
|
||||
graph_exec_id: str,
|
||||
@@ -197,7 +197,8 @@ class AgentExecutorBlock(Block):
|
||||
await execution_utils.stop_graph_execution(
|
||||
graph_exec_id=graph_exec_id,
|
||||
user_id=user_id,
|
||||
wait_timeout=3600,
|
||||
)
|
||||
logger.info(f"Execution {log_id} stopped successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop execution {log_id}: {e}")
|
||||
except TimeoutError as e:
|
||||
logger.error(f"Execution {log_id} stop timed out: {e}")
|
||||
|
||||
@@ -9,6 +9,24 @@ from backend.sdk import BaseModel, Credentials, Requests
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_bools(
|
||||
obj: Any,
|
||||
) -> Any: # noqa: ANN401 – allow Any for deep conversion utility
|
||||
"""Recursively walk *obj* and coerce string booleans to real booleans."""
|
||||
if isinstance(obj, str):
|
||||
lowered = obj.lower()
|
||||
if lowered == "true":
|
||||
return True
|
||||
if lowered == "false":
|
||||
return False
|
||||
return obj
|
||||
if isinstance(obj, list):
|
||||
return [_convert_bools(item) for item in obj]
|
||||
if isinstance(obj, dict):
|
||||
return {k: _convert_bools(v) for k, v in obj.items()}
|
||||
return obj
|
||||
|
||||
|
||||
class WebhookFilters(BaseModel):
|
||||
dataTypes: list[str]
|
||||
changeTypes: list[str] | None = None
|
||||
@@ -579,7 +597,7 @@ async def update_table(
|
||||
response = await Requests().patch(
|
||||
f"https://api.airtable.com/v0/meta/bases/{base_id}/tables/{table_id}",
|
||||
headers={"Authorization": credentials.auth_header()},
|
||||
json=params,
|
||||
json=_convert_bools(params),
|
||||
)
|
||||
|
||||
return response.json()
|
||||
@@ -609,7 +627,7 @@ async def create_field(
|
||||
response = await Requests().post(
|
||||
f"https://api.airtable.com/v0/meta/bases/{base_id}/tables/{table_id}/fields",
|
||||
headers={"Authorization": credentials.auth_header()},
|
||||
json=params,
|
||||
json=_convert_bools(params),
|
||||
)
|
||||
return response.json()
|
||||
|
||||
@@ -633,7 +651,7 @@ async def update_field(
|
||||
response = await Requests().patch(
|
||||
f"https://api.airtable.com/v0/meta/bases/{base_id}/tables/{table_id}/fields/{field_id}",
|
||||
headers={"Authorization": credentials.auth_header()},
|
||||
json=params,
|
||||
json=_convert_bools(params),
|
||||
)
|
||||
return response.json()
|
||||
|
||||
@@ -691,7 +709,7 @@ async def list_records(
|
||||
response = await Requests().get(
|
||||
f"https://api.airtable.com/v0/{base_id}/{table_id_or_name}",
|
||||
headers={"Authorization": credentials.auth_header()},
|
||||
params=params,
|
||||
json=_convert_bools(params),
|
||||
)
|
||||
return response.json()
|
||||
|
||||
@@ -720,20 +738,22 @@ async def update_multiple_records(
|
||||
typecast: bool | None = None,
|
||||
) -> dict[str, dict[str, dict[str, str]]]:
|
||||
|
||||
params: dict[str, str | dict[str, list[str]] | list[dict[str, dict[str, str]]]] = {}
|
||||
params: dict[
|
||||
str, str | bool | dict[str, list[str]] | list[dict[str, dict[str, str]]]
|
||||
] = {}
|
||||
if perform_upsert:
|
||||
params["performUpsert"] = perform_upsert
|
||||
if return_fields_by_field_id:
|
||||
params["returnFieldsByFieldId"] = str(return_fields_by_field_id)
|
||||
if typecast:
|
||||
params["typecast"] = str(typecast)
|
||||
params["typecast"] = typecast
|
||||
|
||||
params["records"] = records
|
||||
params["records"] = [_convert_bools(record) for record in records]
|
||||
|
||||
response = await Requests().patch(
|
||||
f"https://api.airtable.com/v0/{base_id}/{table_id_or_name}",
|
||||
headers={"Authorization": credentials.auth_header()},
|
||||
json=params,
|
||||
json=_convert_bools(params),
|
||||
)
|
||||
return response.json()
|
||||
|
||||
@@ -747,18 +767,20 @@ async def update_record(
|
||||
typecast: bool | None = None,
|
||||
fields: dict[str, Any] | None = None,
|
||||
) -> dict[str, dict[str, dict[str, str]]]:
|
||||
params: dict[str, str | dict[str, Any] | list[dict[str, dict[str, str]]]] = {}
|
||||
params: dict[str, str | bool | dict[str, Any] | list[dict[str, dict[str, str]]]] = (
|
||||
{}
|
||||
)
|
||||
if return_fields_by_field_id:
|
||||
params["returnFieldsByFieldId"] = str(return_fields_by_field_id)
|
||||
params["returnFieldsByFieldId"] = return_fields_by_field_id
|
||||
if typecast:
|
||||
params["typecast"] = str(typecast)
|
||||
params["typecast"] = typecast
|
||||
if fields:
|
||||
params["fields"] = fields
|
||||
|
||||
response = await Requests().patch(
|
||||
f"https://api.airtable.com/v0/{base_id}/{table_id_or_name}/{record_id}",
|
||||
headers={"Authorization": credentials.auth_header()},
|
||||
json=params,
|
||||
json=_convert_bools(params),
|
||||
)
|
||||
return response.json()
|
||||
|
||||
@@ -779,21 +801,22 @@ async def create_record(
|
||||
len(records) <= 10
|
||||
), "Only up to 10 records can be provided when using records"
|
||||
|
||||
params: dict[str, str | dict[str, Any] | list[dict[str, Any]]] = {}
|
||||
params: dict[str, str | bool | dict[str, Any] | list[dict[str, Any]]] = {}
|
||||
if fields:
|
||||
params["fields"] = fields
|
||||
if records:
|
||||
params["records"] = records
|
||||
if return_fields_by_field_id:
|
||||
params["returnFieldsByFieldId"] = str(return_fields_by_field_id)
|
||||
params["returnFieldsByFieldId"] = return_fields_by_field_id
|
||||
if typecast:
|
||||
params["typecast"] = str(typecast)
|
||||
params["typecast"] = typecast
|
||||
|
||||
response = await Requests().post(
|
||||
f"https://api.airtable.com/v0/{base_id}/{table_id_or_name}",
|
||||
headers={"Authorization": credentials.auth_header()},
|
||||
json=params,
|
||||
json=_convert_bools(params),
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
|
||||
@@ -850,7 +873,7 @@ async def create_webhook(
|
||||
response = await Requests().post(
|
||||
f"https://api.airtable.com/v0/bases/{base_id}/webhooks",
|
||||
headers={"Authorization": credentials.auth_header()},
|
||||
json=params,
|
||||
json=_convert_bools(params),
|
||||
)
|
||||
return response.json()
|
||||
|
||||
@@ -1195,7 +1218,7 @@ async def create_base(
|
||||
"Authorization": credentials.auth_header(),
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=params,
|
||||
json=_convert_bools(params),
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
@@ -159,6 +159,7 @@ class AirtableOAuthHandler(BaseOAuthHandler):
|
||||
logger.info("Successfully refreshed tokens")
|
||||
|
||||
new_credentials = OAuth2Credentials(
|
||||
id=credentials.id,
|
||||
access_token=SecretStr(response.access_token),
|
||||
refresh_token=SecretStr(response.refresh_token),
|
||||
access_token_expires_at=int(time.time()) + response.expires_in,
|
||||
|
||||
@@ -4,11 +4,19 @@ from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.block import BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import SchemaField, UserIntegrations
|
||||
from backend.integrations.ayrshare import AyrshareClient
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
|
||||
|
||||
async def get_profile_key(user_id: str):
|
||||
user_integrations: UserIntegrations = (
|
||||
await get_database_manager_async_client().get_user_integrations(user_id)
|
||||
)
|
||||
return user_integrations.managed_credentials.ayrshare_profile_key
|
||||
|
||||
|
||||
class BaseAyrshareInput(BlockSchema):
|
||||
"""Base input model for Ayrshare social media posts with common fields."""
|
||||
|
||||
|
||||
@@ -6,10 +6,9 @@ from backend.sdk import (
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToBlueskyBlock(Block):
|
||||
@@ -58,10 +57,12 @@ class PostToBlueskyBlock(Block):
|
||||
self,
|
||||
input_data: "PostToBlueskyBlock.Input",
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Bluesky with Bluesky-specific options."""
|
||||
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
@@ -6,10 +6,14 @@ from backend.sdk import (
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, CarouselItem, create_ayrshare_client
|
||||
from ._util import (
|
||||
BaseAyrshareInput,
|
||||
CarouselItem,
|
||||
create_ayrshare_client,
|
||||
get_profile_key,
|
||||
)
|
||||
|
||||
|
||||
class PostToFacebookBlock(Block):
|
||||
@@ -116,10 +120,11 @@ class PostToFacebookBlock(Block):
|
||||
self,
|
||||
input_data: "PostToFacebookBlock.Input",
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Facebook with Facebook-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
@@ -6,10 +6,9 @@ from backend.sdk import (
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToGMBBlock(Block):
|
||||
@@ -111,9 +110,10 @@ class PostToGMBBlock(Block):
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: "PostToGMBBlock.Input", *, profile_key: SecretStr, **kwargs
|
||||
self, input_data: "PostToGMBBlock.Input", *, user_id: str, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""Post to Google My Business with GMB-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
@@ -8,10 +8,14 @@ from backend.sdk import (
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, InstagramUserTag, create_ayrshare_client
|
||||
from ._util import (
|
||||
BaseAyrshareInput,
|
||||
InstagramUserTag,
|
||||
create_ayrshare_client,
|
||||
get_profile_key,
|
||||
)
|
||||
|
||||
|
||||
class PostToInstagramBlock(Block):
|
||||
@@ -108,10 +112,11 @@ class PostToInstagramBlock(Block):
|
||||
self,
|
||||
input_data: "PostToInstagramBlock.Input",
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Instagram with Instagram-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
@@ -6,10 +6,9 @@ from backend.sdk import (
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToLinkedInBlock(Block):
|
||||
@@ -113,10 +112,11 @@ class PostToLinkedInBlock(Block):
|
||||
self,
|
||||
input_data: "PostToLinkedInBlock.Input",
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to LinkedIn with LinkedIn-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
@@ -6,10 +6,14 @@ from backend.sdk import (
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, PinterestCarouselOption, create_ayrshare_client
|
||||
from ._util import (
|
||||
BaseAyrshareInput,
|
||||
PinterestCarouselOption,
|
||||
create_ayrshare_client,
|
||||
get_profile_key,
|
||||
)
|
||||
|
||||
|
||||
class PostToPinterestBlock(Block):
|
||||
@@ -88,10 +92,11 @@ class PostToPinterestBlock(Block):
|
||||
self,
|
||||
input_data: "PostToPinterestBlock.Input",
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Pinterest with Pinterest-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
@@ -6,10 +6,9 @@ from backend.sdk import (
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToRedditBlock(Block):
|
||||
@@ -36,8 +35,9 @@ class PostToRedditBlock(Block):
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: "PostToRedditBlock.Input", *, profile_key: SecretStr, **kwargs
|
||||
self, input_data: "PostToRedditBlock.Input", *, user_id: str, **kwargs
|
||||
) -> BlockOutput:
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
@@ -6,10 +6,9 @@ from backend.sdk import (
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToSnapchatBlock(Block):
|
||||
@@ -63,10 +62,11 @@ class PostToSnapchatBlock(Block):
|
||||
self,
|
||||
input_data: "PostToSnapchatBlock.Input",
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Snapchat with Snapchat-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
@@ -6,10 +6,9 @@ from backend.sdk import (
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToTelegramBlock(Block):
|
||||
@@ -58,10 +57,11 @@ class PostToTelegramBlock(Block):
|
||||
self,
|
||||
input_data: "PostToTelegramBlock.Input",
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Telegram with Telegram-specific validation."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
@@ -6,10 +6,9 @@ from backend.sdk import (
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToThreadsBlock(Block):
|
||||
@@ -51,10 +50,11 @@ class PostToThreadsBlock(Block):
|
||||
self,
|
||||
input_data: "PostToThreadsBlock.Input",
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Threads with Threads-specific validation."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from enum import Enum
|
||||
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
@@ -6,10 +8,15 @@ from backend.sdk import (
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class TikTokVisibility(str, Enum):
|
||||
PUBLIC = "public"
|
||||
PRIVATE = "private"
|
||||
FOLLOWERS = "followers"
|
||||
|
||||
|
||||
class PostToTikTokBlock(Block):
|
||||
@@ -21,7 +28,6 @@ class PostToTikTokBlock(Block):
|
||||
# Override post field to include TikTok-specific information
|
||||
post: str = SchemaField(
|
||||
description="The post text (max 2,200 chars, empty string allowed). Use @handle to mention users. Line breaks will be ignored.",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
@@ -34,7 +40,7 @@ class PostToTikTokBlock(Block):
|
||||
|
||||
# TikTok-specific options
|
||||
auto_add_music: bool = SchemaField(
|
||||
description="Automatically add recommended music to image posts",
|
||||
description="Whether to automatically add recommended music to the post. If you set this field to true, you can change the music later in the TikTok app.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
@@ -54,17 +60,17 @@ class PostToTikTokBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
is_ai_generated: bool = SchemaField(
|
||||
description="Label content as AI-generated (video only)",
|
||||
description="If you enable the toggle, your video will be labeled as “Creator labeled as AI-generated” once posted and can’t be changed. The “Creator labeled as AI-generated” label indicates that the content was completely AI-generated or significantly edited with AI.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
is_branded_content: bool = SchemaField(
|
||||
description="Label as branded content (paid partnership)",
|
||||
description="Whether to enable the Branded Content toggle. If this field is set to true, the video will be labeled as Branded Content, indicating you are in a paid partnership with a brand. A “Paid partnership” label will be attached to the video.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
is_brand_organic: bool = SchemaField(
|
||||
description="Label as brand organic content (promotional)",
|
||||
description="Whether to enable the Brand Organic Content toggle. If this field is set to true, the video will be labeled as Brand Organic Content, indicating you are promoting yourself or your own business. A “Promotional content” label will be attached to the video.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
@@ -81,9 +87,9 @@ class PostToTikTokBlock(Block):
|
||||
default=0,
|
||||
advanced=True,
|
||||
)
|
||||
visibility: str = SchemaField(
|
||||
visibility: TikTokVisibility = SchemaField(
|
||||
description="Post visibility: 'public', 'private', 'followers', or 'friends'",
|
||||
default="public",
|
||||
default=TikTokVisibility.PUBLIC,
|
||||
advanced=True,
|
||||
)
|
||||
draft: bool = SchemaField(
|
||||
@@ -98,7 +104,6 @@ class PostToTikTokBlock(Block):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
disabled=True,
|
||||
id="7faf4b27-96b0-4f05-bf64-e0de54ae74e1",
|
||||
description="Post to TikTok using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
@@ -108,9 +113,10 @@ class PostToTikTokBlock(Block):
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: "PostToTikTokBlock.Input", *, profile_key: SecretStr, **kwargs
|
||||
self, input_data: "PostToTikTokBlock.Input", *, user_id: str, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""Post to TikTok with TikTok-specific validation and options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
@@ -160,12 +166,6 @@ class PostToTikTokBlock(Block):
|
||||
yield "error", f"Image cover index {input_data.image_cover_index} is out of range (max: {len(input_data.media_urls) - 1})"
|
||||
return
|
||||
|
||||
# Validate visibility option
|
||||
valid_visibility = ["public", "private", "followers", "friends"]
|
||||
if input_data.visibility not in valid_visibility:
|
||||
yield "error", f"TikTok visibility must be one of: {', '.join(valid_visibility)}"
|
||||
return
|
||||
|
||||
# Check for PNG files (not supported)
|
||||
has_png = any(url.lower().endswith(".png") for url in input_data.media_urls)
|
||||
if has_png:
|
||||
@@ -218,8 +218,8 @@ class PostToTikTokBlock(Block):
|
||||
if input_data.title:
|
||||
tiktok_options["title"] = input_data.title
|
||||
|
||||
if input_data.visibility != "public":
|
||||
tiktok_options["visibility"] = input_data.visibility
|
||||
if input_data.visibility != TikTokVisibility.PUBLIC:
|
||||
tiktok_options["visibility"] = input_data.visibility.value
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
|
||||
@@ -6,10 +6,9 @@ from backend.sdk import (
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToXBlock(Block):
|
||||
@@ -116,10 +115,11 @@ class PostToXBlock(Block):
|
||||
self,
|
||||
input_data: "PostToXBlock.Input",
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to X / Twitter with enhanced X-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
@@ -9,10 +9,9 @@ from backend.sdk import (
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class YouTubeVisibility(str, Enum):
|
||||
@@ -138,10 +137,12 @@ class PostToYouTubeBlock(Block):
|
||||
self,
|
||||
input_data: "PostToYouTubeBlock.Input",
|
||||
*,
|
||||
profile_key: SecretStr,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to YouTube with YouTube-specific validation and options."""
|
||||
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
408
autogpt_platform/backend/backend/blocks/enrichlayer/_api.py
Normal file
408
autogpt_platform/backend/backend/blocks/enrichlayer/_api.py
Normal file
@@ -0,0 +1,408 @@
|
||||
"""
|
||||
API module for Enrichlayer integration.
|
||||
|
||||
This module provides a client for interacting with the Enrichlayer API,
|
||||
which allows fetching LinkedIn profile data and related information.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import enum
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.model import APIKeyCredentials
|
||||
from backend.util.request import Requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class EnrichlayerAPIException(Exception):
|
||||
"""Exception raised for Enrichlayer API errors."""
|
||||
|
||||
def __init__(self, message: str, status_code: int):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class FallbackToCache(enum.Enum):
|
||||
ON_ERROR = "on-error"
|
||||
NEVER = "never"
|
||||
|
||||
|
||||
class UseCache(enum.Enum):
|
||||
IF_PRESENT = "if-present"
|
||||
NEVER = "never"
|
||||
|
||||
|
||||
class SocialMediaProfiles(BaseModel):
|
||||
"""Social media profiles model."""
|
||||
|
||||
twitter: Optional[str] = None
|
||||
facebook: Optional[str] = None
|
||||
github: Optional[str] = None
|
||||
|
||||
|
||||
class Experience(BaseModel):
|
||||
"""Experience model for LinkedIn profiles."""
|
||||
|
||||
company: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
location: Optional[str] = None
|
||||
starts_at: Optional[dict[str, int]] = None
|
||||
ends_at: Optional[dict[str, int]] = None
|
||||
company_linkedin_profile_url: Optional[str] = None
|
||||
|
||||
|
||||
class Education(BaseModel):
|
||||
"""Education model for LinkedIn profiles."""
|
||||
|
||||
school: Optional[str] = None
|
||||
degree_name: Optional[str] = None
|
||||
field_of_study: Optional[str] = None
|
||||
starts_at: Optional[dict[str, int]] = None
|
||||
ends_at: Optional[dict[str, int]] = None
|
||||
school_linkedin_profile_url: Optional[str] = None
|
||||
|
||||
|
||||
class PersonProfileResponse(BaseModel):
|
||||
"""Response model for LinkedIn person profile.
|
||||
|
||||
This model represents the response from Enrichlayer's LinkedIn profile API.
|
||||
The API returns comprehensive profile data including work experience,
|
||||
education, skills, and contact information (when available).
|
||||
|
||||
Example API Response:
|
||||
{
|
||||
"public_identifier": "johnsmith",
|
||||
"full_name": "John Smith",
|
||||
"occupation": "Software Engineer at Tech Corp",
|
||||
"experiences": [
|
||||
{
|
||||
"company": "Tech Corp",
|
||||
"title": "Software Engineer",
|
||||
"starts_at": {"year": 2020, "month": 1}
|
||||
}
|
||||
],
|
||||
"education": [...],
|
||||
"skills": ["Python", "JavaScript", ...]
|
||||
}
|
||||
"""
|
||||
|
||||
public_identifier: Optional[str] = None
|
||||
profile_pic_url: Optional[str] = None
|
||||
full_name: Optional[str] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
occupation: Optional[str] = None
|
||||
headline: Optional[str] = None
|
||||
summary: Optional[str] = None
|
||||
country: Optional[str] = None
|
||||
country_full_name: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
state: Optional[str] = None
|
||||
experiences: Optional[list[Experience]] = None
|
||||
education: Optional[list[Education]] = None
|
||||
languages: Optional[list[str]] = None
|
||||
skills: Optional[list[str]] = None
|
||||
inferred_salary: Optional[dict[str, Any]] = None
|
||||
personal_email: Optional[str] = None
|
||||
personal_contact_number: Optional[str] = None
|
||||
social_media_profiles: Optional[SocialMediaProfiles] = None
|
||||
extra: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class SimilarProfile(BaseModel):
|
||||
"""Similar profile model for LinkedIn person lookup."""
|
||||
|
||||
similarity: float
|
||||
linkedin_profile_url: str
|
||||
|
||||
|
||||
class PersonLookupResponse(BaseModel):
|
||||
"""Response model for LinkedIn person lookup.
|
||||
|
||||
This model represents the response from Enrichlayer's person lookup API.
|
||||
The API returns a LinkedIn profile URL and similarity scores when
|
||||
searching for a person by name and company.
|
||||
|
||||
Example API Response:
|
||||
{
|
||||
"url": "https://www.linkedin.com/in/johnsmith/",
|
||||
"name_similarity_score": 0.95,
|
||||
"company_similarity_score": 0.88,
|
||||
"title_similarity_score": 0.75,
|
||||
"location_similarity_score": 0.60
|
||||
}
|
||||
"""
|
||||
|
||||
url: str | None = None
|
||||
name_similarity_score: float | None
|
||||
company_similarity_score: float | None
|
||||
title_similarity_score: float | None
|
||||
location_similarity_score: float | None
|
||||
last_updated: datetime.datetime | None = None
|
||||
profile: PersonProfileResponse | None = None
|
||||
|
||||
|
||||
class RoleLookupResponse(BaseModel):
|
||||
"""Response model for LinkedIn role lookup.
|
||||
|
||||
This model represents the response from Enrichlayer's role lookup API.
|
||||
The API returns LinkedIn profile data for a specific role at a company.
|
||||
|
||||
Example API Response:
|
||||
{
|
||||
"linkedin_profile_url": "https://www.linkedin.com/in/johnsmith/",
|
||||
"profile_data": {...} // Full PersonProfileResponse data when enrich_profile=True
|
||||
}
|
||||
"""
|
||||
|
||||
linkedin_profile_url: Optional[str] = None
|
||||
profile_data: Optional[PersonProfileResponse] = None
|
||||
|
||||
|
||||
class ProfilePictureResponse(BaseModel):
|
||||
"""Response model for LinkedIn profile picture.
|
||||
|
||||
This model represents the response from Enrichlayer's profile picture API.
|
||||
The API returns a URL to the person's LinkedIn profile picture.
|
||||
|
||||
Example API Response:
|
||||
{
|
||||
"tmp_profile_pic_url": "https://media.licdn.com/dms/image/..."
|
||||
}
|
||||
"""
|
||||
|
||||
tmp_profile_pic_url: str = Field(
|
||||
..., description="URL of the profile picture", alias="tmp_profile_pic_url"
|
||||
)
|
||||
|
||||
@property
|
||||
def profile_picture_url(self) -> str:
|
||||
"""Backward compatibility property for profile_picture_url."""
|
||||
return self.tmp_profile_pic_url
|
||||
|
||||
|
||||
class EnrichlayerClient:
|
||||
"""Client for interacting with the Enrichlayer API."""
|
||||
|
||||
API_BASE_URL = "https://enrichlayer.com/api/v2"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
credentials: Optional[APIKeyCredentials] = None,
|
||||
custom_requests: Optional[Requests] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the Enrichlayer client.
|
||||
|
||||
Args:
|
||||
credentials: The credentials to use for authentication.
|
||||
custom_requests: Custom Requests instance for testing.
|
||||
"""
|
||||
if custom_requests:
|
||||
self._requests = custom_requests
|
||||
else:
|
||||
headers: dict[str, str] = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if credentials:
|
||||
headers["Authorization"] = (
|
||||
f"Bearer {credentials.api_key.get_secret_value()}"
|
||||
)
|
||||
|
||||
self._requests = Requests(
|
||||
extra_headers=headers,
|
||||
raise_for_status=False,
|
||||
)
|
||||
|
||||
async def _handle_response(self, response) -> Any:
|
||||
"""
|
||||
Handle API response and check for errors.
|
||||
|
||||
Args:
|
||||
response: The response object from the request.
|
||||
|
||||
Returns:
|
||||
The response data.
|
||||
|
||||
Raises:
|
||||
EnrichlayerAPIException: If the API request fails.
|
||||
"""
|
||||
if not response.ok:
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_message = error_data.get("message", "")
|
||||
except JSONDecodeError:
|
||||
error_message = response.text
|
||||
|
||||
raise EnrichlayerAPIException(
|
||||
f"Enrichlayer API request failed ({response.status_code}): {error_message}",
|
||||
response.status_code,
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def fetch_profile(
|
||||
self,
|
||||
linkedin_url: str,
|
||||
fallback_to_cache: FallbackToCache = FallbackToCache.ON_ERROR,
|
||||
use_cache: UseCache = UseCache.IF_PRESENT,
|
||||
include_skills: bool = False,
|
||||
include_inferred_salary: bool = False,
|
||||
include_personal_email: bool = False,
|
||||
include_personal_contact_number: bool = False,
|
||||
include_social_media: bool = False,
|
||||
include_extra: bool = False,
|
||||
) -> PersonProfileResponse:
|
||||
"""
|
||||
Fetch a LinkedIn profile with optional parameters.
|
||||
|
||||
Args:
|
||||
linkedin_url: The LinkedIn profile URL to fetch.
|
||||
fallback_to_cache: Cache usage if live fetch fails ('on-error' or 'never').
|
||||
use_cache: Cache utilization ('if-present' or 'never').
|
||||
include_skills: Whether to include skills data.
|
||||
include_inferred_salary: Whether to include inferred salary data.
|
||||
include_personal_email: Whether to include personal email.
|
||||
include_personal_contact_number: Whether to include personal contact number.
|
||||
include_social_media: Whether to include social media profiles.
|
||||
include_extra: Whether to include additional data.
|
||||
|
||||
Returns:
|
||||
The LinkedIn profile data.
|
||||
|
||||
Raises:
|
||||
EnrichlayerAPIException: If the API request fails.
|
||||
"""
|
||||
params = {
|
||||
"url": linkedin_url,
|
||||
"fallback_to_cache": fallback_to_cache.value.lower(),
|
||||
"use_cache": use_cache.value.lower(),
|
||||
}
|
||||
|
||||
if include_skills:
|
||||
params["skills"] = "include"
|
||||
if include_inferred_salary:
|
||||
params["inferred_salary"] = "include"
|
||||
if include_personal_email:
|
||||
params["personal_email"] = "include"
|
||||
if include_personal_contact_number:
|
||||
params["personal_contact_number"] = "include"
|
||||
if include_social_media:
|
||||
params["twitter_profile_id"] = "include"
|
||||
params["facebook_profile_id"] = "include"
|
||||
params["github_profile_id"] = "include"
|
||||
if include_extra:
|
||||
params["extra"] = "include"
|
||||
|
||||
response = await self._requests.get(
|
||||
f"{self.API_BASE_URL}/profile", params=params
|
||||
)
|
||||
return PersonProfileResponse(**await self._handle_response(response))
|
||||
|
||||
async def lookup_person(
|
||||
self,
|
||||
first_name: str,
|
||||
company_domain: str,
|
||||
last_name: str | None = None,
|
||||
location: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
include_similarity_checks: bool = False,
|
||||
enrich_profile: bool = False,
|
||||
) -> PersonLookupResponse:
|
||||
"""
|
||||
Look up a LinkedIn profile by person's information.
|
||||
|
||||
Args:
|
||||
first_name: The person's first name.
|
||||
last_name: The person's last name.
|
||||
company_domain: The domain of the company they work for.
|
||||
location: The person's location.
|
||||
title: The person's job title.
|
||||
include_similarity_checks: Whether to include similarity checks.
|
||||
enrich_profile: Whether to enrich the profile.
|
||||
|
||||
Returns:
|
||||
The LinkedIn profile lookup result.
|
||||
|
||||
Raises:
|
||||
EnrichlayerAPIException: If the API request fails.
|
||||
"""
|
||||
params = {"first_name": first_name, "company_domain": company_domain}
|
||||
|
||||
if last_name:
|
||||
params["last_name"] = last_name
|
||||
if location:
|
||||
params["location"] = location
|
||||
if title:
|
||||
params["title"] = title
|
||||
if include_similarity_checks:
|
||||
params["similarity_checks"] = "include"
|
||||
if enrich_profile:
|
||||
params["enrich_profile"] = "enrich"
|
||||
|
||||
response = await self._requests.get(
|
||||
f"{self.API_BASE_URL}/profile/resolve", params=params
|
||||
)
|
||||
return PersonLookupResponse(**await self._handle_response(response))
|
||||
|
||||
async def lookup_role(
|
||||
self, role: str, company_name: str, enrich_profile: bool = False
|
||||
) -> RoleLookupResponse:
|
||||
"""
|
||||
Look up a LinkedIn profile by role in a company.
|
||||
|
||||
Args:
|
||||
role: The role title (e.g., CEO, CTO).
|
||||
company_name: The name of the company.
|
||||
enrich_profile: Whether to enrich the profile.
|
||||
|
||||
Returns:
|
||||
The LinkedIn profile lookup result.
|
||||
|
||||
Raises:
|
||||
EnrichlayerAPIException: If the API request fails.
|
||||
"""
|
||||
params = {
|
||||
"role": role,
|
||||
"company_name": company_name,
|
||||
}
|
||||
|
||||
if enrich_profile:
|
||||
params["enrich_profile"] = "enrich"
|
||||
|
||||
response = await self._requests.get(
|
||||
f"{self.API_BASE_URL}/find/company/role", params=params
|
||||
)
|
||||
return RoleLookupResponse(**await self._handle_response(response))
|
||||
|
||||
async def get_profile_picture(
|
||||
self, linkedin_profile_url: str
|
||||
) -> ProfilePictureResponse:
|
||||
"""
|
||||
Get a LinkedIn profile picture URL.
|
||||
|
||||
Args:
|
||||
linkedin_profile_url: The LinkedIn profile URL.
|
||||
|
||||
Returns:
|
||||
The profile picture URL.
|
||||
|
||||
Raises:
|
||||
EnrichlayerAPIException: If the API request fails.
|
||||
"""
|
||||
params = {
|
||||
"linkedin_person_profile_url": linkedin_profile_url,
|
||||
}
|
||||
|
||||
response = await self._requests.get(
|
||||
f"{self.API_BASE_URL}/person/profile-picture", params=params
|
||||
)
|
||||
return ProfilePictureResponse(**await self._handle_response(response))
|
||||
34
autogpt_platform/backend/backend/blocks/enrichlayer/_auth.py
Normal file
34
autogpt_platform/backend/backend/blocks/enrichlayer/_auth.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
Authentication module for Enrichlayer API integration.
|
||||
|
||||
This module provides credential types and test credentials for the Enrichlayer API.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials, CredentialsMetaInput
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
# Define the type of credentials input expected for Enrichlayer API
|
||||
EnrichlayerCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.ENRICHLAYER], Literal["api_key"]
|
||||
]
|
||||
|
||||
# Mock credentials for testing Enrichlayer API integration
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="1234a567-89bc-4def-ab12-3456cdef7890",
|
||||
provider="enrichlayer",
|
||||
api_key=SecretStr("mock-enrichlayer-api-key"),
|
||||
title="Mock Enrichlayer API key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
# Dictionary representation of test credentials for input fields
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
527
autogpt_platform/backend/backend/blocks/enrichlayer/linkedin.py
Normal file
527
autogpt_platform/backend/backend/blocks/enrichlayer/linkedin.py
Normal file
@@ -0,0 +1,527 @@
|
||||
"""
|
||||
Block definitions for Enrichlayer API integration.
|
||||
|
||||
This module implements blocks for interacting with the Enrichlayer API,
|
||||
which provides access to LinkedIn profile data and related information.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
from ._api import (
|
||||
EnrichlayerClient,
|
||||
Experience,
|
||||
FallbackToCache,
|
||||
PersonLookupResponse,
|
||||
PersonProfileResponse,
|
||||
RoleLookupResponse,
|
||||
UseCache,
|
||||
)
|
||||
from ._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, EnrichlayerCredentialsInput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GetLinkedinProfileBlock(Block):
|
||||
"""Block to fetch LinkedIn profile data using Enrichlayer API."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
"""Input schema for GetLinkedinProfileBlock."""
|
||||
|
||||
linkedin_url: str = SchemaField(
|
||||
description="LinkedIn profile URL to fetch data from",
|
||||
placeholder="https://www.linkedin.com/in/username/",
|
||||
)
|
||||
fallback_to_cache: FallbackToCache = SchemaField(
|
||||
description="Cache usage if live fetch fails",
|
||||
default=FallbackToCache.ON_ERROR,
|
||||
advanced=True,
|
||||
)
|
||||
use_cache: UseCache = SchemaField(
|
||||
description="Cache utilization strategy",
|
||||
default=UseCache.IF_PRESENT,
|
||||
advanced=True,
|
||||
)
|
||||
include_skills: bool = SchemaField(
|
||||
description="Include skills data",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
include_inferred_salary: bool = SchemaField(
|
||||
description="Include inferred salary data",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
include_personal_email: bool = SchemaField(
|
||||
description="Include personal email",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
include_personal_contact_number: bool = SchemaField(
|
||||
description="Include personal contact number",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
include_social_media: bool = SchemaField(
|
||||
description="Include social media profiles",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
include_extra: bool = SchemaField(
|
||||
description="Include additional data",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
credentials: EnrichlayerCredentialsInput = CredentialsField(
|
||||
description="Enrichlayer API credentials"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
"""Output schema for GetLinkedinProfileBlock."""
|
||||
|
||||
profile: PersonProfileResponse = SchemaField(
|
||||
description="LinkedIn profile data"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize GetLinkedinProfileBlock."""
|
||||
super().__init__(
|
||||
id="f6e0ac73-4f1d-4acb-b4b7-b67066c5984e",
|
||||
description="Fetch LinkedIn profile data using Enrichlayer",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=GetLinkedinProfileBlock.Input,
|
||||
output_schema=GetLinkedinProfileBlock.Output,
|
||||
test_input={
|
||||
"linkedin_url": "https://www.linkedin.com/in/williamhgates/",
|
||||
"include_skills": True,
|
||||
"include_social_media": True,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"profile",
|
||||
PersonProfileResponse(
|
||||
public_identifier="williamhgates",
|
||||
full_name="Bill Gates",
|
||||
occupation="Co-chair at Bill & Melinda Gates Foundation",
|
||||
experiences=[
|
||||
Experience(
|
||||
company="Bill & Melinda Gates Foundation",
|
||||
title="Co-chair",
|
||||
starts_at={"year": 2000},
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={
|
||||
"_fetch_profile": lambda *args, **kwargs: PersonProfileResponse(
|
||||
public_identifier="williamhgates",
|
||||
full_name="Bill Gates",
|
||||
occupation="Co-chair at Bill & Melinda Gates Foundation",
|
||||
experiences=[
|
||||
Experience(
|
||||
company="Bill & Melinda Gates Foundation",
|
||||
title="Co-chair",
|
||||
starts_at={"year": 2000},
|
||||
)
|
||||
],
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _fetch_profile(
|
||||
credentials: APIKeyCredentials,
|
||||
linkedin_url: str,
|
||||
fallback_to_cache: FallbackToCache = FallbackToCache.ON_ERROR,
|
||||
use_cache: UseCache = UseCache.IF_PRESENT,
|
||||
include_skills: bool = False,
|
||||
include_inferred_salary: bool = False,
|
||||
include_personal_email: bool = False,
|
||||
include_personal_contact_number: bool = False,
|
||||
include_social_media: bool = False,
|
||||
include_extra: bool = False,
|
||||
):
|
||||
client = EnrichlayerClient(credentials)
|
||||
profile = await client.fetch_profile(
|
||||
linkedin_url=linkedin_url,
|
||||
fallback_to_cache=fallback_to_cache,
|
||||
use_cache=use_cache,
|
||||
include_skills=include_skills,
|
||||
include_inferred_salary=include_inferred_salary,
|
||||
include_personal_email=include_personal_email,
|
||||
include_personal_contact_number=include_personal_contact_number,
|
||||
include_social_media=include_social_media,
|
||||
include_extra=include_extra,
|
||||
)
|
||||
return profile
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Run the block to fetch LinkedIn profile data.
|
||||
|
||||
Args:
|
||||
input_data: Input parameters for the block
|
||||
credentials: API key credentials for Enrichlayer
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Yields:
|
||||
Tuples of (output_name, output_value)
|
||||
"""
|
||||
try:
|
||||
profile = await self._fetch_profile(
|
||||
credentials=credentials,
|
||||
linkedin_url=input_data.linkedin_url,
|
||||
fallback_to_cache=input_data.fallback_to_cache,
|
||||
use_cache=input_data.use_cache,
|
||||
include_skills=input_data.include_skills,
|
||||
include_inferred_salary=input_data.include_inferred_salary,
|
||||
include_personal_email=input_data.include_personal_email,
|
||||
include_personal_contact_number=input_data.include_personal_contact_number,
|
||||
include_social_media=input_data.include_social_media,
|
||||
include_extra=input_data.include_extra,
|
||||
)
|
||||
yield "profile", profile
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching LinkedIn profile: {str(e)}")
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class LinkedinPersonLookupBlock(Block):
|
||||
"""Block to look up LinkedIn profiles by person's information using Enrichlayer API."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
"""Input schema for LinkedinPersonLookupBlock."""
|
||||
|
||||
first_name: str = SchemaField(
|
||||
description="Person's first name",
|
||||
placeholder="John",
|
||||
advanced=False,
|
||||
)
|
||||
last_name: str | None = SchemaField(
|
||||
description="Person's last name",
|
||||
placeholder="Doe",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
company_domain: str = SchemaField(
|
||||
description="Domain of the company they work for (optional)",
|
||||
placeholder="example.com",
|
||||
advanced=False,
|
||||
)
|
||||
location: Optional[str] = SchemaField(
|
||||
description="Person's location (optional)",
|
||||
placeholder="San Francisco",
|
||||
default=None,
|
||||
)
|
||||
title: Optional[str] = SchemaField(
|
||||
description="Person's job title (optional)",
|
||||
placeholder="CEO",
|
||||
default=None,
|
||||
)
|
||||
include_similarity_checks: bool = SchemaField(
|
||||
description="Include similarity checks",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
enrich_profile: bool = SchemaField(
|
||||
description="Enrich the profile with additional data",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
credentials: EnrichlayerCredentialsInput = CredentialsField(
|
||||
description="Enrichlayer API credentials"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
"""Output schema for LinkedinPersonLookupBlock."""
|
||||
|
||||
lookup_result: PersonLookupResponse = SchemaField(
|
||||
description="LinkedIn profile lookup result"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize LinkedinPersonLookupBlock."""
|
||||
super().__init__(
|
||||
id="d237a98a-5c4b-4a1c-b9e3-e6f9a6c81df7",
|
||||
description="Look up LinkedIn profiles by person information using Enrichlayer",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=LinkedinPersonLookupBlock.Input,
|
||||
output_schema=LinkedinPersonLookupBlock.Output,
|
||||
test_input={
|
||||
"first_name": "Bill",
|
||||
"last_name": "Gates",
|
||||
"company_domain": "gatesfoundation.org",
|
||||
"include_similarity_checks": True,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"lookup_result",
|
||||
PersonLookupResponse(
|
||||
url="https://www.linkedin.com/in/williamhgates/",
|
||||
name_similarity_score=0.93,
|
||||
company_similarity_score=0.83,
|
||||
title_similarity_score=0.3,
|
||||
location_similarity_score=0.20,
|
||||
),
|
||||
)
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={
|
||||
"_lookup_person": lambda *args, **kwargs: PersonLookupResponse(
|
||||
url="https://www.linkedin.com/in/williamhgates/",
|
||||
name_similarity_score=0.93,
|
||||
company_similarity_score=0.83,
|
||||
title_similarity_score=0.3,
|
||||
location_similarity_score=0.20,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _lookup_person(
|
||||
credentials: APIKeyCredentials,
|
||||
first_name: str,
|
||||
company_domain: str,
|
||||
last_name: str | None = None,
|
||||
location: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
include_similarity_checks: bool = False,
|
||||
enrich_profile: bool = False,
|
||||
):
|
||||
client = EnrichlayerClient(credentials=credentials)
|
||||
lookup_result = await client.lookup_person(
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
company_domain=company_domain,
|
||||
location=location,
|
||||
title=title,
|
||||
include_similarity_checks=include_similarity_checks,
|
||||
enrich_profile=enrich_profile,
|
||||
)
|
||||
return lookup_result
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Run the block to look up LinkedIn profiles.
|
||||
|
||||
Args:
|
||||
input_data: Input parameters for the block
|
||||
credentials: API key credentials for Enrichlayer
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Yields:
|
||||
Tuples of (output_name, output_value)
|
||||
"""
|
||||
try:
|
||||
lookup_result = await self._lookup_person(
|
||||
credentials=credentials,
|
||||
first_name=input_data.first_name,
|
||||
last_name=input_data.last_name,
|
||||
company_domain=input_data.company_domain,
|
||||
location=input_data.location,
|
||||
title=input_data.title,
|
||||
include_similarity_checks=input_data.include_similarity_checks,
|
||||
enrich_profile=input_data.enrich_profile,
|
||||
)
|
||||
yield "lookup_result", lookup_result
|
||||
except Exception as e:
|
||||
logger.error(f"Error looking up LinkedIn profile: {str(e)}")
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class LinkedinRoleLookupBlock(Block):
|
||||
"""Block to look up LinkedIn profiles by role in a company using Enrichlayer API."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
"""Input schema for LinkedinRoleLookupBlock."""
|
||||
|
||||
role: str = SchemaField(
|
||||
description="Role title (e.g., CEO, CTO)",
|
||||
placeholder="CEO",
|
||||
)
|
||||
company_name: str = SchemaField(
|
||||
description="Name of the company",
|
||||
placeholder="Microsoft",
|
||||
)
|
||||
enrich_profile: bool = SchemaField(
|
||||
description="Enrich the profile with additional data",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
credentials: EnrichlayerCredentialsInput = CredentialsField(
|
||||
description="Enrichlayer API credentials"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
"""Output schema for LinkedinRoleLookupBlock."""
|
||||
|
||||
role_lookup_result: RoleLookupResponse = SchemaField(
|
||||
description="LinkedIn role lookup result"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize LinkedinRoleLookupBlock."""
|
||||
super().__init__(
|
||||
id="3b9fc742-06d4-49c7-b5ce-7e302dd7c8a7",
|
||||
description="Look up LinkedIn profiles by role in a company using Enrichlayer",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=LinkedinRoleLookupBlock.Input,
|
||||
output_schema=LinkedinRoleLookupBlock.Output,
|
||||
test_input={
|
||||
"role": "Co-chair",
|
||||
"company_name": "Gates Foundation",
|
||||
"enrich_profile": True,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"role_lookup_result",
|
||||
RoleLookupResponse(
|
||||
linkedin_profile_url="https://www.linkedin.com/in/williamhgates/",
|
||||
),
|
||||
)
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={
|
||||
"_lookup_role": lambda *args, **kwargs: RoleLookupResponse(
|
||||
linkedin_profile_url="https://www.linkedin.com/in/williamhgates/",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _lookup_role(
|
||||
credentials: APIKeyCredentials,
|
||||
role: str,
|
||||
company_name: str,
|
||||
enrich_profile: bool = False,
|
||||
):
|
||||
client = EnrichlayerClient(credentials=credentials)
|
||||
role_lookup_result = await client.lookup_role(
|
||||
role=role,
|
||||
company_name=company_name,
|
||||
enrich_profile=enrich_profile,
|
||||
)
|
||||
return role_lookup_result
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Run the block to look up LinkedIn profiles by role.
|
||||
|
||||
Args:
|
||||
input_data: Input parameters for the block
|
||||
credentials: API key credentials for Enrichlayer
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Yields:
|
||||
Tuples of (output_name, output_value)
|
||||
"""
|
||||
try:
|
||||
role_lookup_result = await self._lookup_role(
|
||||
credentials=credentials,
|
||||
role=input_data.role,
|
||||
company_name=input_data.company_name,
|
||||
enrich_profile=input_data.enrich_profile,
|
||||
)
|
||||
yield "role_lookup_result", role_lookup_result
|
||||
except Exception as e:
|
||||
logger.error(f"Error looking up role in company: {str(e)}")
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GetLinkedinProfilePictureBlock(Block):
|
||||
"""Block to get LinkedIn profile pictures using Enrichlayer API."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
"""Input schema for GetLinkedinProfilePictureBlock."""
|
||||
|
||||
linkedin_profile_url: str = SchemaField(
|
||||
description="LinkedIn profile URL",
|
||||
placeholder="https://www.linkedin.com/in/username/",
|
||||
)
|
||||
credentials: EnrichlayerCredentialsInput = CredentialsField(
|
||||
description="Enrichlayer API credentials"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
"""Output schema for GetLinkedinProfilePictureBlock."""
|
||||
|
||||
profile_picture_url: MediaFileType = SchemaField(
|
||||
description="LinkedIn profile picture URL"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize GetLinkedinProfilePictureBlock."""
|
||||
super().__init__(
|
||||
id="68d5a942-9b3f-4e9a-b7c1-d96ea4321f0d",
|
||||
description="Get LinkedIn profile pictures using Enrichlayer",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=GetLinkedinProfilePictureBlock.Input,
|
||||
output_schema=GetLinkedinProfilePictureBlock.Output,
|
||||
test_input={
|
||||
"linkedin_profile_url": "https://www.linkedin.com/in/williamhgates/",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"profile_picture_url",
|
||||
"https://media.licdn.com/dms/image/C4D03AQFj-xjuXrLFSQ/profile-displayphoto-shrink_800_800/0/1576881858598?e=1686787200&v=beta&t=zrQC76QwsfQQIWthfOnrKRBMZ5D-qIAvzLXLmWgYvTk",
|
||||
)
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={
|
||||
"_get_profile_picture": lambda *args, **kwargs: "https://media.licdn.com/dms/image/C4D03AQFj-xjuXrLFSQ/profile-displayphoto-shrink_800_800/0/1576881858598?e=1686787200&v=beta&t=zrQC76QwsfQQIWthfOnrKRBMZ5D-qIAvzLXLmWgYvTk",
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _get_profile_picture(
|
||||
credentials: APIKeyCredentials, linkedin_profile_url: str
|
||||
):
|
||||
client = EnrichlayerClient(credentials=credentials)
|
||||
profile_picture_response = await client.get_profile_picture(
|
||||
linkedin_profile_url=linkedin_profile_url,
|
||||
)
|
||||
return profile_picture_response.profile_picture_url
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Run the block to get LinkedIn profile pictures.
|
||||
|
||||
Args:
|
||||
input_data: Input parameters for the block
|
||||
credentials: API key credentials for Enrichlayer
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Yields:
|
||||
Tuples of (output_name, output_value)
|
||||
"""
|
||||
try:
|
||||
profile_picture = await self._get_profile_picture(
|
||||
credentials=credentials,
|
||||
linkedin_profile_url=input_data.linkedin_profile_url,
|
||||
)
|
||||
yield "profile_picture_url", profile_picture
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting profile picture: {str(e)}")
|
||||
yield "error", str(e)
|
||||
247
autogpt_platform/backend/backend/blocks/exa/model.py
Normal file
247
autogpt_platform/backend/backend/blocks/exa/model.py
Normal file
@@ -0,0 +1,247 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# Enum definitions based on available options
|
||||
class WebsetStatus(str, Enum):
|
||||
IDLE = "idle"
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
PAUSED = "paused"
|
||||
|
||||
|
||||
class WebsetSearchStatus(str, Enum):
|
||||
CREATED = "created"
|
||||
# Add more if known, based on example it's "created"
|
||||
|
||||
|
||||
class ImportStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class ImportFormat(str, Enum):
|
||||
CSV = "csv"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class EnrichmentStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class EnrichmentFormat(str, Enum):
|
||||
TEXT = "text"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class MonitorStatus(str, Enum):
|
||||
ENABLED = "enabled"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class MonitorBehaviorType(str, Enum):
|
||||
SEARCH = "search"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class MonitorRunStatus(str, Enum):
|
||||
CREATED = "created"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class CanceledReason(str, Enum):
|
||||
WEBSET_DELETED = "webset_deleted"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class FailedReason(str, Enum):
|
||||
INVALID_FORMAT = "invalid_format"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class Confidence(str, Enum):
|
||||
HIGH = "high"
|
||||
# Add more if known
|
||||
|
||||
|
||||
# Nested models
|
||||
|
||||
|
||||
class Entity(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
class Criterion(BaseModel):
|
||||
description: str
|
||||
successRate: Optional[int] = None
|
||||
|
||||
|
||||
class ExcludeItem(BaseModel):
|
||||
source: str = Field(default="import")
|
||||
id: str
|
||||
|
||||
|
||||
class Relationship(BaseModel):
|
||||
definition: str
|
||||
limit: Optional[float] = None
|
||||
|
||||
|
||||
class ScopeItem(BaseModel):
|
||||
source: str = Field(default="import")
|
||||
id: str
|
||||
relationship: Optional[Relationship] = None
|
||||
|
||||
|
||||
class Progress(BaseModel):
|
||||
found: int
|
||||
analyzed: int
|
||||
completion: int
|
||||
timeLeft: int
|
||||
|
||||
|
||||
class Bounds(BaseModel):
|
||||
min: int
|
||||
max: int
|
||||
|
||||
|
||||
class Expected(BaseModel):
|
||||
total: int
|
||||
confidence: str = Field(default="high") # Use str or Confidence enum
|
||||
bounds: Bounds
|
||||
|
||||
|
||||
class Recall(BaseModel):
|
||||
expected: Expected
|
||||
reasoning: str
|
||||
|
||||
|
||||
class WebsetSearch(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="webset_search")
|
||||
status: str = Field(default="created") # Or use WebsetSearchStatus
|
||||
websetId: str
|
||||
query: str
|
||||
entity: Entity
|
||||
criteria: List[Criterion]
|
||||
count: int
|
||||
behavior: str = Field(default="override")
|
||||
exclude: List[ExcludeItem]
|
||||
scope: List[ScopeItem]
|
||||
progress: Progress
|
||||
recall: Recall
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
canceledAt: Optional[datetime] = None
|
||||
canceledReason: Optional[str] = Field(default=None) # Or use CanceledReason
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
|
||||
|
||||
class ImportEntity(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
class Import(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="import")
|
||||
status: str = Field(default="pending") # Or use ImportStatus
|
||||
format: str = Field(default="csv") # Or use ImportFormat
|
||||
entity: ImportEntity
|
||||
title: str
|
||||
count: int
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
failedReason: Optional[str] = Field(default=None) # Or use FailedReason
|
||||
failedAt: Optional[datetime] = None
|
||||
failedMessage: Optional[str] = None
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
|
||||
|
||||
class Option(BaseModel):
|
||||
label: str
|
||||
|
||||
|
||||
class WebsetEnrichment(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="webset_enrichment")
|
||||
status: str = Field(default="pending") # Or use EnrichmentStatus
|
||||
websetId: str
|
||||
title: str
|
||||
description: str
|
||||
format: str = Field(default="text") # Or use EnrichmentFormat
|
||||
options: List[Option]
|
||||
instructions: str
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
|
||||
|
||||
class Cadence(BaseModel):
|
||||
cron: str
|
||||
timezone: str = Field(default="Etc/UTC")
|
||||
|
||||
|
||||
class BehaviorConfig(BaseModel):
|
||||
query: Optional[str] = None
|
||||
criteria: Optional[List[Criterion]] = None
|
||||
entity: Optional[Entity] = None
|
||||
count: Optional[int] = None
|
||||
behavior: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class Behavior(BaseModel):
|
||||
type: str = Field(default="search") # Or use MonitorBehaviorType
|
||||
config: BehaviorConfig
|
||||
|
||||
|
||||
class MonitorRun(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="monitor_run")
|
||||
status: str = Field(default="created") # Or use MonitorRunStatus
|
||||
monitorId: str
|
||||
type: str = Field(default="search")
|
||||
completedAt: Optional[datetime] = None
|
||||
failedAt: Optional[datetime] = None
|
||||
failedReason: Optional[str] = None
|
||||
canceledAt: Optional[datetime] = None
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
|
||||
|
||||
class Monitor(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="monitor")
|
||||
status: str = Field(default="enabled") # Or use MonitorStatus
|
||||
websetId: str
|
||||
cadence: Cadence
|
||||
behavior: Behavior
|
||||
lastRun: Optional[MonitorRun] = None
|
||||
nextRunAt: Optional[datetime] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
|
||||
|
||||
class Webset(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="webset")
|
||||
status: WebsetStatus
|
||||
externalId: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
searches: List[WebsetSearch]
|
||||
imports: List[Import]
|
||||
enrichments: List[WebsetEnrichment]
|
||||
monitors: List[Monitor]
|
||||
streams: List[Any]
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ListWebsets(BaseModel):
|
||||
data: List[Webset]
|
||||
hasMore: bool
|
||||
nextCursor: Optional[str] = None
|
||||
@@ -114,6 +114,7 @@ class ExaWebsetWebhookBlock(Block):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
disabled=True,
|
||||
id="d0204ed8-8b81-408d-8b8d-ed087a546228",
|
||||
description="Receive webhook notifications for Exa webset events",
|
||||
categories={BlockCategory.INPUT},
|
||||
|
||||
@@ -1,7 +1,33 @@
|
||||
from typing import Any, Optional
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Dict, List, Optional
|
||||
|
||||
from exa_py import Exa
|
||||
from exa_py.websets.types import (
|
||||
CreateCriterionParameters,
|
||||
CreateEnrichmentParameters,
|
||||
CreateWebsetParameters,
|
||||
CreateWebsetParametersSearch,
|
||||
ExcludeItem,
|
||||
Format,
|
||||
ImportItem,
|
||||
ImportSource,
|
||||
Option,
|
||||
ScopeItem,
|
||||
ScopeRelationship,
|
||||
ScopeSourceType,
|
||||
WebsetArticleEntity,
|
||||
WebsetCompanyEntity,
|
||||
WebsetCustomEntity,
|
||||
WebsetPersonEntity,
|
||||
WebsetResearchPaperEntity,
|
||||
WebsetStatus,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
BaseModel,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
@@ -12,7 +38,69 @@ from backend.sdk import (
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from .helpers import WebsetEnrichmentConfig, WebsetSearchConfig
|
||||
|
||||
|
||||
class SearchEntityType(str, Enum):
|
||||
COMPANY = "company"
|
||||
PERSON = "person"
|
||||
ARTICLE = "article"
|
||||
RESEARCH_PAPER = "research_paper"
|
||||
CUSTOM = "custom"
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
class SearchType(str, Enum):
|
||||
IMPORT = "import"
|
||||
WEBSET = "webset"
|
||||
|
||||
|
||||
class EnrichmentFormat(str, Enum):
|
||||
TEXT = "text"
|
||||
DATE = "date"
|
||||
NUMBER = "number"
|
||||
OPTIONS = "options"
|
||||
EMAIL = "email"
|
||||
PHONE = "phone"
|
||||
|
||||
|
||||
class Webset(BaseModel):
|
||||
id: str
|
||||
status: WebsetStatus | None = Field(..., title="WebsetStatus")
|
||||
"""
|
||||
The status of the webset
|
||||
"""
|
||||
external_id: Annotated[Optional[str], Field(alias="externalId")] = None
|
||||
"""
|
||||
The external identifier for the webset
|
||||
NOTE: Returning dict to avoid ui crashing due to nested objects
|
||||
"""
|
||||
searches: List[dict[str, Any]] | None = None
|
||||
"""
|
||||
The searches that have been performed on the webset.
|
||||
NOTE: Returning dict to avoid ui crashing due to nested objects
|
||||
"""
|
||||
enrichments: List[dict[str, Any]] | None = None
|
||||
"""
|
||||
The Enrichments to apply to the Webset Items.
|
||||
NOTE: Returning dict to avoid ui crashing due to nested objects
|
||||
"""
|
||||
monitors: List[dict[str, Any]] | None = None
|
||||
"""
|
||||
The Monitors for the Webset.
|
||||
NOTE: Returning dict to avoid ui crashing due to nested objects
|
||||
"""
|
||||
metadata: Optional[Dict[str, Any]] = {}
|
||||
"""
|
||||
Set of key-value pairs you want to associate with this object.
|
||||
"""
|
||||
created_at: Annotated[datetime, Field(alias="createdAt")] | None = None
|
||||
"""
|
||||
The date and time the webset was created
|
||||
"""
|
||||
updated_at: Annotated[datetime, Field(alias="updatedAt")] | None = None
|
||||
"""
|
||||
The date and time the webset was last updated
|
||||
"""
|
||||
|
||||
|
||||
class ExaCreateWebsetBlock(Block):
|
||||
@@ -20,40 +108,121 @@ class ExaCreateWebsetBlock(Block):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
search: WebsetSearchConfig = SchemaField(
|
||||
description="Initial search configuration for the Webset"
|
||||
|
||||
# Search parameters (flattened)
|
||||
search_query: str = SchemaField(
|
||||
description="Your search query. Use this to describe what you are looking for. Any URL provided will be crawled and used as context for the search.",
|
||||
placeholder="Marketing agencies based in the US, that focus on consumer products",
|
||||
)
|
||||
enrichments: Optional[list[WebsetEnrichmentConfig]] = SchemaField(
|
||||
default=None,
|
||||
description="Enrichments to apply to Webset items",
|
||||
search_count: Optional[int] = SchemaField(
|
||||
default=10,
|
||||
description="Number of items the search will attempt to find. The actual number of items found may be less than this number depending on the search complexity.",
|
||||
ge=1,
|
||||
le=1000,
|
||||
)
|
||||
search_entity_type: SearchEntityType = SchemaField(
|
||||
default=SearchEntityType.AUTO,
|
||||
description="Entity type: 'company', 'person', 'article', 'research_paper', or 'custom'. If not provided, we automatically detect the entity from the query.",
|
||||
advanced=True,
|
||||
)
|
||||
search_entity_description: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="Description for custom entity type (required when search_entity_type is 'custom')",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Search criteria (flattened)
|
||||
search_criteria: list[str] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of criteria descriptions that every item will be evaluated against. If not provided, we automatically detect the criteria from the query.",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Search exclude sources (flattened)
|
||||
search_exclude_sources: list[str] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of source IDs (imports or websets) to exclude from search results",
|
||||
advanced=True,
|
||||
)
|
||||
search_exclude_types: list[SearchType] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of source types corresponding to exclude sources ('import' or 'webset')",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Search scope sources (flattened)
|
||||
search_scope_sources: list[str] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of source IDs (imports or websets) to limit search scope to",
|
||||
advanced=True,
|
||||
)
|
||||
search_scope_types: list[SearchType] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of source types corresponding to scope sources ('import' or 'webset')",
|
||||
advanced=True,
|
||||
)
|
||||
search_scope_relationships: list[str] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of relationship definitions for hop searches (optional, one per scope source)",
|
||||
advanced=True,
|
||||
)
|
||||
search_scope_relationship_limits: list[int] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of limits on the number of related entities to find (optional, one per scope relationship)",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Import parameters (flattened)
|
||||
import_sources: list[str] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of source IDs to import from",
|
||||
advanced=True,
|
||||
)
|
||||
import_types: list[SearchType] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of source types corresponding to import sources ('import' or 'webset')",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Enrichment parameters (flattened)
|
||||
enrichment_descriptions: list[str] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of enrichment task descriptions to perform on each webset item",
|
||||
advanced=True,
|
||||
)
|
||||
enrichment_formats: list[EnrichmentFormat] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of formats for enrichment responses ('text', 'date', 'number', 'options', 'email', 'phone'). If not specified, we automatically select the best format.",
|
||||
advanced=True,
|
||||
)
|
||||
enrichment_options: list[list[str]] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of option lists for enrichments with 'options' format. Each inner list contains the option labels.",
|
||||
advanced=True,
|
||||
)
|
||||
enrichment_metadata: list[dict] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of metadata dictionaries for enrichments",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Webset metadata
|
||||
external_id: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="External identifier for the webset",
|
||||
description="External identifier for the webset. You can use this to reference the webset by your own internal identifiers.",
|
||||
placeholder="my-webset-123",
|
||||
advanced=True,
|
||||
)
|
||||
metadata: Optional[dict] = SchemaField(
|
||||
default=None,
|
||||
default_factory=dict,
|
||||
description="Key-value pairs to associate with this webset",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
webset_id: str = SchemaField(
|
||||
webset: Webset = SchemaField(
|
||||
description="The unique identifier for the created webset"
|
||||
)
|
||||
status: str = SchemaField(description="The status of the webset")
|
||||
external_id: Optional[str] = SchemaField(
|
||||
description="The external identifier for the webset", default=None
|
||||
)
|
||||
created_at: str = SchemaField(
|
||||
description="The date and time the webset was created"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the request failed", default=""
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -67,44 +236,171 @@ class ExaCreateWebsetBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
url = "https://api.exa.ai/websets/v0/websets"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
|
||||
# Build the payload
|
||||
payload: dict[str, Any] = {
|
||||
"search": input_data.search.model_dump(exclude_none=True),
|
||||
}
|
||||
exa = Exa(credentials.api_key.get_secret_value())
|
||||
|
||||
# Convert enrichments to API format
|
||||
if input_data.enrichments:
|
||||
enrichments_data = []
|
||||
for enrichment in input_data.enrichments:
|
||||
enrichments_data.append(enrichment.model_dump(exclude_none=True))
|
||||
payload["enrichments"] = enrichments_data
|
||||
# ------------------------------------------------------------
|
||||
# Build entity (if explicitly provided)
|
||||
# ------------------------------------------------------------
|
||||
entity = None
|
||||
if input_data.search_entity_type == SearchEntityType.COMPANY:
|
||||
entity = WebsetCompanyEntity(type="company")
|
||||
elif input_data.search_entity_type == SearchEntityType.PERSON:
|
||||
entity = WebsetPersonEntity(type="person")
|
||||
elif input_data.search_entity_type == SearchEntityType.ARTICLE:
|
||||
entity = WebsetArticleEntity(type="article")
|
||||
elif input_data.search_entity_type == SearchEntityType.RESEARCH_PAPER:
|
||||
entity = WebsetResearchPaperEntity(type="research_paper")
|
||||
elif (
|
||||
input_data.search_entity_type == SearchEntityType.CUSTOM
|
||||
and input_data.search_entity_description
|
||||
):
|
||||
entity = WebsetCustomEntity(
|
||||
type="custom", description=input_data.search_entity_description
|
||||
)
|
||||
|
||||
if input_data.external_id:
|
||||
payload["externalId"] = input_data.external_id
|
||||
# ------------------------------------------------------------
|
||||
# Build criteria list
|
||||
# ------------------------------------------------------------
|
||||
criteria = None
|
||||
if input_data.search_criteria:
|
||||
criteria = [
|
||||
CreateCriterionParameters(description=item)
|
||||
for item in input_data.search_criteria
|
||||
]
|
||||
|
||||
if input_data.metadata:
|
||||
payload["metadata"] = input_data.metadata
|
||||
# ------------------------------------------------------------
|
||||
# Build exclude sources list
|
||||
# ------------------------------------------------------------
|
||||
exclude_items = None
|
||||
if input_data.search_exclude_sources:
|
||||
exclude_items = []
|
||||
for idx, src_id in enumerate(input_data.search_exclude_sources):
|
||||
src_type = None
|
||||
if input_data.search_exclude_types and idx < len(
|
||||
input_data.search_exclude_types
|
||||
):
|
||||
src_type = input_data.search_exclude_types[idx]
|
||||
# Default to IMPORT if type missing
|
||||
if src_type == SearchType.WEBSET:
|
||||
source_enum = ImportSource.webset
|
||||
else:
|
||||
source_enum = ImportSource.import_
|
||||
exclude_items.append(ExcludeItem(source=source_enum, id=src_id))
|
||||
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=payload)
|
||||
data = response.json()
|
||||
# ------------------------------------------------------------
|
||||
# Build scope list
|
||||
# ------------------------------------------------------------
|
||||
scope_items = None
|
||||
if input_data.search_scope_sources:
|
||||
scope_items = []
|
||||
for idx, src_id in enumerate(input_data.search_scope_sources):
|
||||
src_type = None
|
||||
if input_data.search_scope_types and idx < len(
|
||||
input_data.search_scope_types
|
||||
):
|
||||
src_type = input_data.search_scope_types[idx]
|
||||
relationship = None
|
||||
if input_data.search_scope_relationships and idx < len(
|
||||
input_data.search_scope_relationships
|
||||
):
|
||||
rel_def = input_data.search_scope_relationships[idx]
|
||||
lim = None
|
||||
if input_data.search_scope_relationship_limits and idx < len(
|
||||
input_data.search_scope_relationship_limits
|
||||
):
|
||||
lim = input_data.search_scope_relationship_limits[idx]
|
||||
relationship = ScopeRelationship(definition=rel_def, limit=lim)
|
||||
if src_type == SearchType.WEBSET:
|
||||
src_enum = ScopeSourceType.webset
|
||||
else:
|
||||
src_enum = ScopeSourceType.import_
|
||||
scope_items.append(
|
||||
ScopeItem(source=src_enum, id=src_id, relationship=relationship)
|
||||
)
|
||||
|
||||
yield "webset_id", data.get("id", "")
|
||||
yield "status", data.get("status", "")
|
||||
yield "external_id", data.get("externalId")
|
||||
yield "created_at", data.get("createdAt", "")
|
||||
# ------------------------------------------------------------
|
||||
# Assemble search parameters (only if a query is provided)
|
||||
# ------------------------------------------------------------
|
||||
search_params = None
|
||||
if input_data.search_query:
|
||||
search_params = CreateWebsetParametersSearch(
|
||||
query=input_data.search_query,
|
||||
count=input_data.search_count,
|
||||
entity=entity,
|
||||
criteria=criteria,
|
||||
exclude=exclude_items,
|
||||
scope=scope_items,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
yield "webset_id", ""
|
||||
yield "status", ""
|
||||
yield "created_at", ""
|
||||
# ------------------------------------------------------------
|
||||
# Build imports list
|
||||
# ------------------------------------------------------------
|
||||
imports_params = None
|
||||
if input_data.import_sources:
|
||||
imports_params = []
|
||||
for idx, src_id in enumerate(input_data.import_sources):
|
||||
src_type = None
|
||||
if input_data.import_types and idx < len(input_data.import_types):
|
||||
src_type = input_data.import_types[idx]
|
||||
if src_type == SearchType.WEBSET:
|
||||
source_enum = ImportSource.webset
|
||||
else:
|
||||
source_enum = ImportSource.import_
|
||||
imports_params.append(ImportItem(source=source_enum, id=src_id))
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Build enrichment list
|
||||
# ------------------------------------------------------------
|
||||
enrichments_params = None
|
||||
if input_data.enrichment_descriptions:
|
||||
enrichments_params = []
|
||||
for idx, desc in enumerate(input_data.enrichment_descriptions):
|
||||
fmt = None
|
||||
if input_data.enrichment_formats and idx < len(
|
||||
input_data.enrichment_formats
|
||||
):
|
||||
fmt_enum = input_data.enrichment_formats[idx]
|
||||
if fmt_enum is not None:
|
||||
fmt = Format(
|
||||
fmt_enum.value if isinstance(fmt_enum, Enum) else fmt_enum
|
||||
)
|
||||
options_list = None
|
||||
if input_data.enrichment_options and idx < len(
|
||||
input_data.enrichment_options
|
||||
):
|
||||
raw_opts = input_data.enrichment_options[idx]
|
||||
if raw_opts:
|
||||
options_list = [Option(label=o) for o in raw_opts]
|
||||
metadata_obj = None
|
||||
if input_data.enrichment_metadata and idx < len(
|
||||
input_data.enrichment_metadata
|
||||
):
|
||||
metadata_obj = input_data.enrichment_metadata[idx]
|
||||
enrichments_params.append(
|
||||
CreateEnrichmentParameters(
|
||||
description=desc,
|
||||
format=fmt,
|
||||
options=options_list,
|
||||
metadata=metadata_obj,
|
||||
)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Create the webset
|
||||
# ------------------------------------------------------------
|
||||
webset = exa.websets.create(
|
||||
params=CreateWebsetParameters(
|
||||
search=search_params,
|
||||
imports=imports_params,
|
||||
enrichments=enrichments_params,
|
||||
external_id=input_data.external_id,
|
||||
metadata=input_data.metadata,
|
||||
)
|
||||
)
|
||||
|
||||
# Use alias field names returned from Exa SDK so that nested models validate correctly
|
||||
yield "webset", Webset.model_validate(webset.model_dump(by_alias=True))
|
||||
|
||||
|
||||
class ExaUpdateWebsetBlock(Block):
|
||||
@@ -183,6 +479,11 @@ class ExaListWebsetsBlock(Block):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
trigger: Any | None = SchemaField(
|
||||
default=None,
|
||||
description="Trigger for the webset, value is ignored!",
|
||||
advanced=False,
|
||||
)
|
||||
cursor: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="Cursor for pagination through results",
|
||||
@@ -197,7 +498,9 @@ class ExaListWebsetsBlock(Block):
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
websets: list = SchemaField(description="List of websets", default_factory=list)
|
||||
websets: list[Webset] = SchemaField(
|
||||
description="List of websets", default_factory=list
|
||||
)
|
||||
has_more: bool = SchemaField(
|
||||
description="Whether there are more results to paginate through",
|
||||
default=False,
|
||||
@@ -255,9 +558,6 @@ class ExaGetWebsetBlock(Block):
|
||||
description="The ID or external ID of the Webset to retrieve",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
expand_items: bool = SchemaField(
|
||||
default=False, description="Include items in the response", advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
webset_id: str = SchemaField(description="The unique identifier for the webset")
|
||||
@@ -309,12 +609,8 @@ class ExaGetWebsetBlock(Block):
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
|
||||
params = {}
|
||||
if input_data.expand_items:
|
||||
params["expand[]"] = "items"
|
||||
|
||||
try:
|
||||
response = await Requests().get(url, headers=headers, params=params)
|
||||
response = await Requests().get(url, headers=headers)
|
||||
data = response.json()
|
||||
|
||||
yield "webset_id", data.get("id", "")
|
||||
|
||||
5
autogpt_platform/backend/backend/blocks/firecrawl/extract.py
Normal file → Executable file
5
autogpt_platform/backend/backend/blocks/firecrawl/extract.py
Normal file → Executable file
@@ -29,8 +29,8 @@ class FirecrawlExtractBlock(Block):
|
||||
prompt: str | None = SchemaField(
|
||||
description="The prompt to use for the crawl", default=None, advanced=False
|
||||
)
|
||||
output_schema: str | None = SchemaField(
|
||||
description="A more rigid structure if you already know the JSON layout.",
|
||||
output_schema: dict | None = SchemaField(
|
||||
description="A Json Schema describing the output structure if more rigid structure is desired.",
|
||||
default=None,
|
||||
)
|
||||
enable_web_search: bool = SchemaField(
|
||||
@@ -56,7 +56,6 @@ class FirecrawlExtractBlock(Block):
|
||||
|
||||
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Sync call
|
||||
extract_result = app.extract(
|
||||
urls=input_data.urls,
|
||||
prompt=input_data.prompt,
|
||||
|
||||
388
autogpt_platform/backend/backend/blocks/github/ci.py
Normal file
388
autogpt_platform/backend/backend/blocks/github/ci.py
Normal file
@@ -0,0 +1,388 @@
|
||||
import logging
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CheckRunStatus(Enum):
|
||||
QUEUED = "queued"
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
|
||||
|
||||
class CheckRunConclusion(Enum):
|
||||
SUCCESS = "success"
|
||||
FAILURE = "failure"
|
||||
NEUTRAL = "neutral"
|
||||
CANCELLED = "cancelled"
|
||||
SKIPPED = "skipped"
|
||||
TIMED_OUT = "timed_out"
|
||||
ACTION_REQUIRED = "action_required"
|
||||
|
||||
|
||||
class GithubGetCIResultsBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo: str = SchemaField(
|
||||
description="GitHub repository",
|
||||
placeholder="owner/repo",
|
||||
)
|
||||
target: str | int = SchemaField(
|
||||
description="Commit SHA or PR number to get CI results for",
|
||||
placeholder="abc123def or 123",
|
||||
)
|
||||
search_pattern: Optional[str] = SchemaField(
|
||||
description="Optional regex pattern to search for in CI logs (e.g., error messages, file names)",
|
||||
placeholder=".*error.*|.*warning.*",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
check_name_filter: Optional[str] = SchemaField(
|
||||
description="Optional filter for specific check names (supports wildcards)",
|
||||
placeholder="*lint* or build-*",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class CheckRunItem(TypedDict, total=False):
|
||||
id: int
|
||||
name: str
|
||||
status: str
|
||||
conclusion: Optional[str]
|
||||
started_at: Optional[str]
|
||||
completed_at: Optional[str]
|
||||
html_url: str
|
||||
details_url: Optional[str]
|
||||
output_title: Optional[str]
|
||||
output_summary: Optional[str]
|
||||
output_text: Optional[str]
|
||||
annotations: list[dict]
|
||||
|
||||
class MatchedLine(TypedDict):
|
||||
check_name: str
|
||||
line_number: int
|
||||
line: str
|
||||
context: list[str]
|
||||
|
||||
check_run: CheckRunItem = SchemaField(
|
||||
title="Check Run",
|
||||
description="Individual CI check run with details",
|
||||
)
|
||||
check_runs: list[CheckRunItem] = SchemaField(
|
||||
description="List of all CI check runs"
|
||||
)
|
||||
matched_line: MatchedLine = SchemaField(
|
||||
title="Matched Line",
|
||||
description="Line matching the search pattern with context",
|
||||
)
|
||||
matched_lines: list[MatchedLine] = SchemaField(
|
||||
description="All lines matching the search pattern across all checks"
|
||||
)
|
||||
overall_status: str = SchemaField(
|
||||
description="Overall CI status (pending, success, failure)"
|
||||
)
|
||||
overall_conclusion: str = SchemaField(
|
||||
description="Overall CI conclusion if completed"
|
||||
)
|
||||
total_checks: int = SchemaField(description="Total number of CI checks")
|
||||
passed_checks: int = SchemaField(description="Number of passed checks")
|
||||
failed_checks: int = SchemaField(description="Number of failed checks")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8ad9e103-78f2-4fdb-ba12-3571f2c95e98",
|
||||
description="This block gets CI results for a commit or PR, with optional search for specific errors/warnings in logs.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubGetCIResultsBlock.Input,
|
||||
output_schema=GithubGetCIResultsBlock.Output,
|
||||
test_input={
|
||||
"repo": "owner/repo",
|
||||
"target": "abc123def456",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("overall_status", "completed"),
|
||||
("overall_conclusion", "success"),
|
||||
("total_checks", 1),
|
||||
("passed_checks", 1),
|
||||
("failed_checks", 0),
|
||||
(
|
||||
"check_runs",
|
||||
[
|
||||
{
|
||||
"id": 123456,
|
||||
"name": "build",
|
||||
"status": "completed",
|
||||
"conclusion": "success",
|
||||
"started_at": "2024-01-01T00:00:00Z",
|
||||
"completed_at": "2024-01-01T00:05:00Z",
|
||||
"html_url": "https://github.com/owner/repo/runs/123456",
|
||||
"details_url": None,
|
||||
"output_title": "Build passed",
|
||||
"output_summary": "All tests passed",
|
||||
"output_text": "Build log output...",
|
||||
"annotations": [],
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_ci_results": lambda *args, **kwargs: {
|
||||
"check_runs": [
|
||||
{
|
||||
"id": 123456,
|
||||
"name": "build",
|
||||
"status": "completed",
|
||||
"conclusion": "success",
|
||||
"started_at": "2024-01-01T00:00:00Z",
|
||||
"completed_at": "2024-01-01T00:05:00Z",
|
||||
"html_url": "https://github.com/owner/repo/runs/123456",
|
||||
"details_url": None,
|
||||
"output_title": "Build passed",
|
||||
"output_summary": "All tests passed",
|
||||
"output_text": "Build log output...",
|
||||
"annotations": [],
|
||||
}
|
||||
],
|
||||
"total_count": 1,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_commit_sha(api, repo: str, target: str | int) -> str:
|
||||
"""Get commit SHA from either a commit SHA or PR URL."""
|
||||
# If it's already a SHA, return it
|
||||
|
||||
if isinstance(target, str):
|
||||
if re.match(r"^[0-9a-f]{6,40}$", target, re.IGNORECASE):
|
||||
return target
|
||||
|
||||
# If it's a PR URL, get the head SHA
|
||||
if isinstance(target, int):
|
||||
pr_url = f"https://api.github.com/repos/{repo}/pulls/{target}"
|
||||
response = await api.get(pr_url)
|
||||
pr_data = response.json()
|
||||
return pr_data["head"]["sha"]
|
||||
|
||||
raise ValueError("Target must be a commit SHA or PR URL")
|
||||
|
||||
@staticmethod
|
||||
async def search_in_logs(
|
||||
check_runs: list,
|
||||
pattern: str,
|
||||
) -> list[Output.MatchedLine]:
|
||||
"""Search for pattern in check run logs."""
|
||||
if not pattern:
|
||||
return []
|
||||
|
||||
matched_lines = []
|
||||
regex = re.compile(pattern, re.IGNORECASE | re.MULTILINE)
|
||||
|
||||
for check in check_runs:
|
||||
output_text = check.get("output_text", "") or ""
|
||||
if not output_text:
|
||||
continue
|
||||
|
||||
lines = output_text.split("\n")
|
||||
for i, line in enumerate(lines):
|
||||
if regex.search(line):
|
||||
# Get context (2 lines before and after)
|
||||
start = max(0, i - 2)
|
||||
end = min(len(lines), i + 3)
|
||||
context = lines[start:end]
|
||||
|
||||
matched_lines.append(
|
||||
{
|
||||
"check_name": check["name"],
|
||||
"line_number": i + 1,
|
||||
"line": line,
|
||||
"context": context,
|
||||
}
|
||||
)
|
||||
|
||||
return matched_lines
|
||||
|
||||
@staticmethod
|
||||
async def get_ci_results(
|
||||
credentials: GithubCredentials,
|
||||
repo: str,
|
||||
target: str | int,
|
||||
search_pattern: Optional[str] = None,
|
||||
check_name_filter: Optional[str] = None,
|
||||
) -> dict:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
|
||||
# Get the commit SHA
|
||||
commit_sha = await GithubGetCIResultsBlock.get_commit_sha(api, repo, target)
|
||||
|
||||
# Get check runs for the commit
|
||||
check_runs_url = (
|
||||
f"https://api.github.com/repos/{repo}/commits/{commit_sha}/check-runs"
|
||||
)
|
||||
|
||||
# Get all pages of check runs
|
||||
all_check_runs = []
|
||||
page = 1
|
||||
per_page = 100
|
||||
|
||||
while True:
|
||||
response = await api.get(
|
||||
check_runs_url, params={"per_page": per_page, "page": page}
|
||||
)
|
||||
data = response.json()
|
||||
|
||||
check_runs = data.get("check_runs", [])
|
||||
all_check_runs.extend(check_runs)
|
||||
|
||||
if len(check_runs) < per_page:
|
||||
break
|
||||
page += 1
|
||||
|
||||
# Filter by check name if specified
|
||||
if check_name_filter:
|
||||
import fnmatch
|
||||
|
||||
filtered_runs = []
|
||||
for run in all_check_runs:
|
||||
if fnmatch.fnmatch(run["name"].lower(), check_name_filter.lower()):
|
||||
filtered_runs.append(run)
|
||||
all_check_runs = filtered_runs
|
||||
|
||||
# Get check run details with logs
|
||||
detailed_runs = []
|
||||
for run in all_check_runs:
|
||||
# Get detailed output including logs
|
||||
if run.get("output", {}).get("text"):
|
||||
# Already has output
|
||||
detailed_run = {
|
||||
"id": run["id"],
|
||||
"name": run["name"],
|
||||
"status": run["status"],
|
||||
"conclusion": run.get("conclusion"),
|
||||
"started_at": run.get("started_at"),
|
||||
"completed_at": run.get("completed_at"),
|
||||
"html_url": run["html_url"],
|
||||
"details_url": run.get("details_url"),
|
||||
"output_title": run.get("output", {}).get("title"),
|
||||
"output_summary": run.get("output", {}).get("summary"),
|
||||
"output_text": run.get("output", {}).get("text"),
|
||||
"annotations": [],
|
||||
}
|
||||
else:
|
||||
# Try to get logs from the check run
|
||||
detailed_run = {
|
||||
"id": run["id"],
|
||||
"name": run["name"],
|
||||
"status": run["status"],
|
||||
"conclusion": run.get("conclusion"),
|
||||
"started_at": run.get("started_at"),
|
||||
"completed_at": run.get("completed_at"),
|
||||
"html_url": run["html_url"],
|
||||
"details_url": run.get("details_url"),
|
||||
"output_title": run.get("output", {}).get("title"),
|
||||
"output_summary": run.get("output", {}).get("summary"),
|
||||
"output_text": None,
|
||||
"annotations": [],
|
||||
}
|
||||
|
||||
# Get annotations if available
|
||||
if run.get("output", {}).get("annotations_count", 0) > 0:
|
||||
annotations_url = f"https://api.github.com/repos/{repo}/check-runs/{run['id']}/annotations"
|
||||
try:
|
||||
ann_response = await api.get(annotations_url)
|
||||
detailed_run["annotations"] = ann_response.json()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
detailed_runs.append(detailed_run)
|
||||
|
||||
return {
|
||||
"check_runs": detailed_runs,
|
||||
"total_count": len(detailed_runs),
|
||||
}
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
try:
|
||||
target = int(input_data.target)
|
||||
except ValueError:
|
||||
target = input_data.target
|
||||
|
||||
result = await self.get_ci_results(
|
||||
credentials,
|
||||
input_data.repo,
|
||||
target,
|
||||
input_data.search_pattern,
|
||||
input_data.check_name_filter,
|
||||
)
|
||||
|
||||
check_runs = result["check_runs"]
|
||||
|
||||
# Calculate overall status
|
||||
if not check_runs:
|
||||
yield "overall_status", "no_checks"
|
||||
yield "overall_conclusion", "no_checks"
|
||||
else:
|
||||
all_completed = all(run["status"] == "completed" for run in check_runs)
|
||||
if all_completed:
|
||||
yield "overall_status", "completed"
|
||||
# Determine overall conclusion
|
||||
has_failure = any(
|
||||
run["conclusion"] in ["failure", "timed_out", "action_required"]
|
||||
for run in check_runs
|
||||
)
|
||||
if has_failure:
|
||||
yield "overall_conclusion", "failure"
|
||||
else:
|
||||
yield "overall_conclusion", "success"
|
||||
else:
|
||||
yield "overall_status", "pending"
|
||||
yield "overall_conclusion", "pending"
|
||||
|
||||
# Count checks
|
||||
total = len(check_runs)
|
||||
passed = sum(1 for run in check_runs if run.get("conclusion") == "success")
|
||||
failed = sum(
|
||||
1 for run in check_runs if run.get("conclusion") in ["failure", "timed_out"]
|
||||
)
|
||||
|
||||
yield "total_checks", total
|
||||
yield "passed_checks", passed
|
||||
yield "failed_checks", failed
|
||||
|
||||
# Output check runs
|
||||
yield "check_runs", check_runs
|
||||
|
||||
# Search for patterns if specified
|
||||
if input_data.search_pattern:
|
||||
matched_lines = await self.search_in_logs(
|
||||
check_runs, input_data.search_pattern
|
||||
)
|
||||
if matched_lines:
|
||||
yield "matched_lines", matched_lines
|
||||
840
autogpt_platform/backend/backend/blocks/github/reviews.py
Normal file
840
autogpt_platform/backend/backend/blocks/github/reviews.py
Normal file
@@ -0,0 +1,840 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReviewEvent(Enum):
|
||||
COMMENT = "COMMENT"
|
||||
APPROVE = "APPROVE"
|
||||
REQUEST_CHANGES = "REQUEST_CHANGES"
|
||||
|
||||
|
||||
class GithubCreatePRReviewBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
class ReviewComment(TypedDict, total=False):
|
||||
path: str
|
||||
position: Optional[int]
|
||||
body: str
|
||||
line: Optional[int] # Will be used as position if position not provided
|
||||
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo: str = SchemaField(
|
||||
description="GitHub repository",
|
||||
placeholder="owner/repo",
|
||||
)
|
||||
pr_number: int = SchemaField(
|
||||
description="Pull request number",
|
||||
placeholder="123",
|
||||
)
|
||||
body: str = SchemaField(
|
||||
description="Body of the review comment",
|
||||
placeholder="Enter your review comment",
|
||||
)
|
||||
event: ReviewEvent = SchemaField(
|
||||
description="The review action to perform",
|
||||
default=ReviewEvent.COMMENT,
|
||||
)
|
||||
create_as_draft: bool = SchemaField(
|
||||
description="Create the review as a draft (pending) or post it immediately",
|
||||
default=False,
|
||||
advanced=False,
|
||||
)
|
||||
comments: Optional[List[ReviewComment]] = SchemaField(
|
||||
description="Optional inline comments to add to specific files/lines. Note: Only path, body, and position are supported. Position is line number in diff from first @@ hunk.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
review_id: int = SchemaField(description="ID of the created review")
|
||||
state: str = SchemaField(
|
||||
description="State of the review (e.g., PENDING, COMMENTED, APPROVED, CHANGES_REQUESTED)"
|
||||
)
|
||||
html_url: str = SchemaField(description="URL of the created review")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the review creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="84754b30-97d2-4c37-a3b8-eb39f268275b",
|
||||
description="This block creates a review on a GitHub pull request with optional inline comments. You can create it as a draft or post immediately. Note: For inline comments, 'position' should be the line number in the diff (starting from the first @@ hunk header).",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCreatePRReviewBlock.Input,
|
||||
output_schema=GithubCreatePRReviewBlock.Output,
|
||||
test_input={
|
||||
"repo": "owner/repo",
|
||||
"pr_number": 1,
|
||||
"body": "This looks good to me!",
|
||||
"event": "APPROVE",
|
||||
"create_as_draft": False,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("review_id", 123456),
|
||||
("state", "APPROVED"),
|
||||
(
|
||||
"html_url",
|
||||
"https://github.com/owner/repo/pull/1#pullrequestreview-123456",
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"create_review": lambda *args, **kwargs: (
|
||||
123456,
|
||||
"APPROVED",
|
||||
"https://github.com/owner/repo/pull/1#pullrequestreview-123456",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_review(
|
||||
credentials: GithubCredentials,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
body: str,
|
||||
event: ReviewEvent,
|
||||
create_as_draft: bool,
|
||||
comments: Optional[List[Input.ReviewComment]] = None,
|
||||
) -> tuple[int, str, str]:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
|
||||
# GitHub API endpoint for creating reviews
|
||||
reviews_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews"
|
||||
|
||||
# Get commit_id if we have comments
|
||||
commit_id = None
|
||||
if comments:
|
||||
# Get PR details to get the head commit for inline comments
|
||||
pr_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}"
|
||||
pr_response = await api.get(pr_url)
|
||||
pr_data = pr_response.json()
|
||||
commit_id = pr_data["head"]["sha"]
|
||||
|
||||
# Prepare the request data
|
||||
# If create_as_draft is True, omit the event field (creates a PENDING review)
|
||||
# Otherwise, use the actual event value which will auto-submit the review
|
||||
data: dict[str, Any] = {"body": body}
|
||||
|
||||
# Add commit_id if we have it
|
||||
if commit_id:
|
||||
data["commit_id"] = commit_id
|
||||
|
||||
# Add comments if provided
|
||||
if comments:
|
||||
# Process comments to ensure they have the required fields
|
||||
processed_comments = []
|
||||
for comment in comments:
|
||||
comment_data: dict = {
|
||||
"path": comment.get("path", ""),
|
||||
"body": comment.get("body", ""),
|
||||
}
|
||||
# Add position or line
|
||||
# Note: For review comments, only position is supported (not line/side)
|
||||
if "position" in comment and comment.get("position") is not None:
|
||||
comment_data["position"] = comment.get("position")
|
||||
elif "line" in comment and comment.get("line") is not None:
|
||||
# Note: Using line as position - may not work correctly
|
||||
# Position should be calculated from the diff
|
||||
comment_data["position"] = comment.get("line")
|
||||
|
||||
# Note: side, start_line, and start_side are NOT supported for review comments
|
||||
# They are only for standalone PR comments
|
||||
|
||||
processed_comments.append(comment_data)
|
||||
|
||||
data["comments"] = processed_comments
|
||||
|
||||
if not create_as_draft:
|
||||
# Only add event field if not creating a draft
|
||||
data["event"] = event.value
|
||||
|
||||
# Create the review
|
||||
response = await api.post(reviews_url, json=data)
|
||||
review_data = response.json()
|
||||
|
||||
return review_data["id"], review_data["state"], review_data["html_url"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
review_id, state, html_url = await self.create_review(
|
||||
credentials,
|
||||
input_data.repo,
|
||||
input_data.pr_number,
|
||||
input_data.body,
|
||||
input_data.event,
|
||||
input_data.create_as_draft,
|
||||
input_data.comments,
|
||||
)
|
||||
yield "review_id", review_id
|
||||
yield "state", state
|
||||
yield "html_url", html_url
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubListPRReviewsBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo: str = SchemaField(
|
||||
description="GitHub repository",
|
||||
placeholder="owner/repo",
|
||||
)
|
||||
pr_number: int = SchemaField(
|
||||
description="Pull request number",
|
||||
placeholder="123",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class ReviewItem(TypedDict):
|
||||
id: int
|
||||
user: str
|
||||
state: str
|
||||
body: str
|
||||
html_url: str
|
||||
|
||||
review: ReviewItem = SchemaField(
|
||||
title="Review",
|
||||
description="Individual review with details",
|
||||
)
|
||||
reviews: list[ReviewItem] = SchemaField(
|
||||
description="List of all reviews on the pull request"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing reviews failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f79bc6eb-33c0-4099-9c0f-d664ae1ba4d0",
|
||||
description="This block lists all reviews for a specified GitHub pull request.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListPRReviewsBlock.Input,
|
||||
output_schema=GithubListPRReviewsBlock.Output,
|
||||
test_input={
|
||||
"repo": "owner/repo",
|
||||
"pr_number": 1,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"reviews",
|
||||
[
|
||||
{
|
||||
"id": 123456,
|
||||
"user": "reviewer1",
|
||||
"state": "APPROVED",
|
||||
"body": "Looks good!",
|
||||
"html_url": "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"review",
|
||||
{
|
||||
"id": 123456,
|
||||
"user": "reviewer1",
|
||||
"state": "APPROVED",
|
||||
"body": "Looks good!",
|
||||
"html_url": "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_reviews": lambda *args, **kwargs: [
|
||||
{
|
||||
"id": 123456,
|
||||
"user": "reviewer1",
|
||||
"state": "APPROVED",
|
||||
"body": "Looks good!",
|
||||
"html_url": "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def list_reviews(
|
||||
credentials: GithubCredentials, repo: str, pr_number: int
|
||||
) -> list[Output.ReviewItem]:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
|
||||
# GitHub API endpoint for listing reviews
|
||||
reviews_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews"
|
||||
|
||||
response = await api.get(reviews_url)
|
||||
data = response.json()
|
||||
|
||||
reviews: list[GithubListPRReviewsBlock.Output.ReviewItem] = [
|
||||
{
|
||||
"id": review["id"],
|
||||
"user": review["user"]["login"],
|
||||
"state": review["state"],
|
||||
"body": review.get("body", ""),
|
||||
"html_url": review["html_url"],
|
||||
}
|
||||
for review in data
|
||||
]
|
||||
return reviews
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
reviews = await self.list_reviews(
|
||||
credentials,
|
||||
input_data.repo,
|
||||
input_data.pr_number,
|
||||
)
|
||||
yield "reviews", reviews
|
||||
for review in reviews:
|
||||
yield "review", review
|
||||
|
||||
|
||||
class GithubSubmitPendingReviewBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo: str = SchemaField(
|
||||
description="GitHub repository",
|
||||
placeholder="owner/repo",
|
||||
)
|
||||
pr_number: int = SchemaField(
|
||||
description="Pull request number",
|
||||
placeholder="123",
|
||||
)
|
||||
review_id: int = SchemaField(
|
||||
description="ID of the pending review to submit",
|
||||
placeholder="123456",
|
||||
)
|
||||
event: ReviewEvent = SchemaField(
|
||||
description="The review action to perform when submitting",
|
||||
default=ReviewEvent.COMMENT,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
state: str = SchemaField(description="State of the submitted review")
|
||||
html_url: str = SchemaField(description="URL of the submitted review")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the review submission failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2e468217-7ca0-4201-9553-36e93eb9357a",
|
||||
description="This block submits a pending (draft) review on a GitHub pull request.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubSubmitPendingReviewBlock.Input,
|
||||
output_schema=GithubSubmitPendingReviewBlock.Output,
|
||||
test_input={
|
||||
"repo": "owner/repo",
|
||||
"pr_number": 1,
|
||||
"review_id": 123456,
|
||||
"event": "APPROVE",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("state", "APPROVED"),
|
||||
(
|
||||
"html_url",
|
||||
"https://github.com/owner/repo/pull/1#pullrequestreview-123456",
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"submit_review": lambda *args, **kwargs: (
|
||||
"APPROVED",
|
||||
"https://github.com/owner/repo/pull/1#pullrequestreview-123456",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def submit_review(
|
||||
credentials: GithubCredentials,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
review_id: int,
|
||||
event: ReviewEvent,
|
||||
) -> tuple[str, str]:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
|
||||
# GitHub API endpoint for submitting a review
|
||||
submit_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews/{review_id}/events"
|
||||
|
||||
data = {"event": event.value}
|
||||
|
||||
response = await api.post(submit_url, json=data)
|
||||
review_data = response.json()
|
||||
|
||||
return review_data["state"], review_data["html_url"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
state, html_url = await self.submit_review(
|
||||
credentials,
|
||||
input_data.repo,
|
||||
input_data.pr_number,
|
||||
input_data.review_id,
|
||||
input_data.event,
|
||||
)
|
||||
yield "state", state
|
||||
yield "html_url", html_url
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubResolveReviewDiscussionBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo: str = SchemaField(
|
||||
description="GitHub repository",
|
||||
placeholder="owner/repo",
|
||||
)
|
||||
pr_number: int = SchemaField(
|
||||
description="Pull request number",
|
||||
placeholder="123",
|
||||
)
|
||||
comment_id: int = SchemaField(
|
||||
description="ID of the review comment to resolve/unresolve",
|
||||
placeholder="123456",
|
||||
)
|
||||
resolve: bool = SchemaField(
|
||||
description="Whether to resolve (true) or unresolve (false) the discussion",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the operation was successful")
|
||||
error: str = SchemaField(description="Error message if the operation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b4b8a38c-95ae-4c91-9ef8-c2cffaf2b5d1",
|
||||
description="This block resolves or unresolves a review discussion thread on a GitHub pull request.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubResolveReviewDiscussionBlock.Input,
|
||||
output_schema=GithubResolveReviewDiscussionBlock.Output,
|
||||
test_input={
|
||||
"repo": "owner/repo",
|
||||
"pr_number": 1,
|
||||
"comment_id": 123456,
|
||||
"resolve": True,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"resolve_discussion": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def resolve_discussion(
|
||||
credentials: GithubCredentials,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
comment_id: int,
|
||||
resolve: bool,
|
||||
) -> bool:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
|
||||
# Extract owner and repo name
|
||||
parts = repo.split("/")
|
||||
owner = parts[0]
|
||||
repo_name = parts[1]
|
||||
|
||||
# GitHub GraphQL API is needed for resolving/unresolving discussions
|
||||
# First, we need to get the node ID of the comment
|
||||
graphql_url = "https://api.github.com/graphql"
|
||||
|
||||
# Query to get the review comment node ID
|
||||
query = """
|
||||
query($owner: String!, $repo: String!, $number: Int!) {
|
||||
repository(owner: $owner, name: $repo) {
|
||||
pullRequest(number: $number) {
|
||||
reviewThreads(first: 100) {
|
||||
nodes {
|
||||
comments(first: 100) {
|
||||
nodes {
|
||||
databaseId
|
||||
id
|
||||
}
|
||||
}
|
||||
id
|
||||
isResolved
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
variables = {"owner": owner, "repo": repo_name, "number": pr_number}
|
||||
|
||||
response = await api.post(
|
||||
graphql_url, json={"query": query, "variables": variables}
|
||||
)
|
||||
data = response.json()
|
||||
|
||||
# Find the thread containing our comment
|
||||
thread_id = None
|
||||
for thread in data["data"]["repository"]["pullRequest"]["reviewThreads"][
|
||||
"nodes"
|
||||
]:
|
||||
for comment in thread["comments"]["nodes"]:
|
||||
if comment["databaseId"] == comment_id:
|
||||
thread_id = thread["id"]
|
||||
break
|
||||
if thread_id:
|
||||
break
|
||||
|
||||
if not thread_id:
|
||||
raise ValueError(f"Comment {comment_id} not found in pull request")
|
||||
|
||||
# Now resolve or unresolve the thread
|
||||
# GitHub's GraphQL API has separate mutations for resolve and unresolve
|
||||
if resolve:
|
||||
mutation = """
|
||||
mutation($threadId: ID!) {
|
||||
resolveReviewThread(input: {threadId: $threadId}) {
|
||||
thread {
|
||||
isResolved
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
else:
|
||||
mutation = """
|
||||
mutation($threadId: ID!) {
|
||||
unresolveReviewThread(input: {threadId: $threadId}) {
|
||||
thread {
|
||||
isResolved
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
mutation_variables = {"threadId": thread_id}
|
||||
|
||||
response = await api.post(
|
||||
graphql_url, json={"query": mutation, "variables": mutation_variables}
|
||||
)
|
||||
result = response.json()
|
||||
|
||||
if "errors" in result:
|
||||
raise Exception(f"GraphQL error: {result['errors']}")
|
||||
|
||||
return True
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = await self.resolve_discussion(
|
||||
credentials,
|
||||
input_data.repo,
|
||||
input_data.pr_number,
|
||||
input_data.comment_id,
|
||||
input_data.resolve,
|
||||
)
|
||||
yield "success", success
|
||||
except Exception as e:
|
||||
yield "success", False
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubGetPRReviewCommentsBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo: str = SchemaField(
|
||||
description="GitHub repository",
|
||||
placeholder="owner/repo",
|
||||
)
|
||||
pr_number: int = SchemaField(
|
||||
description="Pull request number",
|
||||
placeholder="123",
|
||||
)
|
||||
review_id: Optional[int] = SchemaField(
|
||||
description="ID of a specific review to get comments from (optional)",
|
||||
placeholder="123456",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class CommentItem(TypedDict):
|
||||
id: int
|
||||
user: str
|
||||
body: str
|
||||
path: str
|
||||
line: int
|
||||
side: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
in_reply_to_id: Optional[int]
|
||||
html_url: str
|
||||
|
||||
comment: CommentItem = SchemaField(
|
||||
title="Comment",
|
||||
description="Individual review comment with details",
|
||||
)
|
||||
comments: list[CommentItem] = SchemaField(
|
||||
description="List of all review comments on the pull request"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if getting comments failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1d34db7f-10c1-45c1-9d43-749f743c8bd4",
|
||||
description="This block gets all review comments from a GitHub pull request or from a specific review.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubGetPRReviewCommentsBlock.Input,
|
||||
output_schema=GithubGetPRReviewCommentsBlock.Output,
|
||||
test_input={
|
||||
"repo": "owner/repo",
|
||||
"pr_number": 1,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"comments",
|
||||
[
|
||||
{
|
||||
"id": 123456,
|
||||
"user": "reviewer1",
|
||||
"body": "This needs improvement",
|
||||
"path": "src/main.py",
|
||||
"line": 42,
|
||||
"side": "RIGHT",
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
"updated_at": "2024-01-01T00:00:00Z",
|
||||
"in_reply_to_id": None,
|
||||
"html_url": "https://github.com/owner/repo/pull/1#discussion_r123456",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"comment",
|
||||
{
|
||||
"id": 123456,
|
||||
"user": "reviewer1",
|
||||
"body": "This needs improvement",
|
||||
"path": "src/main.py",
|
||||
"line": 42,
|
||||
"side": "RIGHT",
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
"updated_at": "2024-01-01T00:00:00Z",
|
||||
"in_reply_to_id": None,
|
||||
"html_url": "https://github.com/owner/repo/pull/1#discussion_r123456",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_comments": lambda *args, **kwargs: [
|
||||
{
|
||||
"id": 123456,
|
||||
"user": "reviewer1",
|
||||
"body": "This needs improvement",
|
||||
"path": "src/main.py",
|
||||
"line": 42,
|
||||
"side": "RIGHT",
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
"updated_at": "2024-01-01T00:00:00Z",
|
||||
"in_reply_to_id": None,
|
||||
"html_url": "https://github.com/owner/repo/pull/1#discussion_r123456",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_comments(
|
||||
credentials: GithubCredentials,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
review_id: Optional[int] = None,
|
||||
) -> list[Output.CommentItem]:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
|
||||
# Determine the endpoint based on whether we want comments from a specific review
|
||||
if review_id:
|
||||
# Get comments from a specific review
|
||||
comments_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews/{review_id}/comments"
|
||||
else:
|
||||
# Get all review comments on the PR
|
||||
comments_url = (
|
||||
f"https://api.github.com/repos/{repo}/pulls/{pr_number}/comments"
|
||||
)
|
||||
|
||||
response = await api.get(comments_url)
|
||||
data = response.json()
|
||||
|
||||
comments: list[GithubGetPRReviewCommentsBlock.Output.CommentItem] = [
|
||||
{
|
||||
"id": comment["id"],
|
||||
"user": comment["user"]["login"],
|
||||
"body": comment["body"],
|
||||
"path": comment.get("path", ""),
|
||||
"line": comment.get("line", 0),
|
||||
"side": comment.get("side", ""),
|
||||
"created_at": comment["created_at"],
|
||||
"updated_at": comment["updated_at"],
|
||||
"in_reply_to_id": comment.get("in_reply_to_id"),
|
||||
"html_url": comment["html_url"],
|
||||
}
|
||||
for comment in data
|
||||
]
|
||||
return comments
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
comments = await self.get_comments(
|
||||
credentials,
|
||||
input_data.repo,
|
||||
input_data.pr_number,
|
||||
input_data.review_id,
|
||||
)
|
||||
yield "comments", comments
|
||||
for comment in comments:
|
||||
yield "comment", comment
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubCreateCommentObjectBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
path: str = SchemaField(
|
||||
description="The file path to comment on",
|
||||
placeholder="src/main.py",
|
||||
)
|
||||
body: str = SchemaField(
|
||||
description="The comment text",
|
||||
placeholder="Please fix this issue",
|
||||
)
|
||||
position: Optional[int] = SchemaField(
|
||||
description="Position in the diff (line number from first @@ hunk). Use this OR line.",
|
||||
placeholder="6",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
line: Optional[int] = SchemaField(
|
||||
description="Line number in the file (will be used as position if position not provided)",
|
||||
placeholder="42",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
side: Optional[str] = SchemaField(
|
||||
description="Side of the diff to comment on (NOTE: Only for standalone comments, not review comments)",
|
||||
default="RIGHT",
|
||||
advanced=True,
|
||||
)
|
||||
start_line: Optional[int] = SchemaField(
|
||||
description="Start line for multi-line comments (NOTE: Only for standalone comments, not review comments)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
start_side: Optional[str] = SchemaField(
|
||||
description="Side for the start of multi-line comments (NOTE: Only for standalone comments, not review comments)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
comment_object: dict = SchemaField(
|
||||
description="The comment object formatted for GitHub API"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b7d5e4f2-8c3a-4e6b-9f1d-7a8b9c5e4d3f",
|
||||
description="Creates a comment object for use with GitHub blocks. Note: For review comments, only path, body, and position are used. Side fields are only for standalone PR comments.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCreateCommentObjectBlock.Input,
|
||||
output_schema=GithubCreateCommentObjectBlock.Output,
|
||||
test_input={
|
||||
"path": "src/main.py",
|
||||
"body": "Please fix this issue",
|
||||
"position": 6,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"comment_object",
|
||||
{
|
||||
"path": "src/main.py",
|
||||
"body": "Please fix this issue",
|
||||
"position": 6,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Build the comment object
|
||||
comment_obj: dict = {
|
||||
"path": input_data.path,
|
||||
"body": input_data.body,
|
||||
}
|
||||
|
||||
# Add position or line
|
||||
if input_data.position is not None:
|
||||
comment_obj["position"] = input_data.position
|
||||
elif input_data.line is not None:
|
||||
# Note: line will be used as position, which may not be accurate
|
||||
# Position should be calculated from the diff
|
||||
comment_obj["position"] = input_data.line
|
||||
|
||||
# Add optional fields only if they differ from defaults or are explicitly provided
|
||||
if input_data.side and input_data.side != "RIGHT":
|
||||
comment_obj["side"] = input_data.side
|
||||
if input_data.start_line is not None:
|
||||
comment_obj["start_line"] = input_data.start_line
|
||||
if input_data.start_side:
|
||||
comment_obj["start_side"] = input_data.start_side
|
||||
|
||||
yield "comment_object", comment_obj
|
||||
@@ -21,6 +21,8 @@ from ._auth import (
|
||||
GoogleCredentialsInput,
|
||||
)
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
class CalendarEvent(BaseModel):
|
||||
"""Structured representation of a Google Calendar event."""
|
||||
@@ -221,8 +223,8 @@ class GoogleCalendarReadEventsBlock(Block):
|
||||
else None
|
||||
),
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=Settings().secrets.google_client_id,
|
||||
client_secret=Settings().secrets.google_client_secret,
|
||||
client_id=settings.secrets.google_client_id,
|
||||
client_secret=settings.secrets.google_client_secret,
|
||||
scopes=credentials.scopes,
|
||||
)
|
||||
return build("calendar", "v3", credentials=creds)
|
||||
@@ -569,8 +571,8 @@ class GoogleCalendarCreateEventBlock(Block):
|
||||
else None
|
||||
),
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=Settings().secrets.google_client_id,
|
||||
client_secret=Settings().secrets.google_client_secret,
|
||||
client_id=settings.secrets.google_client_id,
|
||||
client_secret=settings.secrets.google_client_secret,
|
||||
scopes=credentials.scopes,
|
||||
)
|
||||
return build("calendar", "v3", credentials=creds)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -37,6 +37,7 @@ LLMProviderName = Literal[
|
||||
ProviderName.OPENAI,
|
||||
ProviderName.OPEN_ROUTER,
|
||||
ProviderName.LLAMA_API,
|
||||
ProviderName.V0,
|
||||
]
|
||||
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]
|
||||
|
||||
@@ -81,6 +82,11 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
O3 = "o3-2025-04-16"
|
||||
O1 = "o1"
|
||||
O1_MINI = "o1-mini"
|
||||
# GPT-5 models
|
||||
GPT5 = "gpt-5-2025-08-07"
|
||||
GPT5_MINI = "gpt-5-mini-2025-08-07"
|
||||
GPT5_NANO = "gpt-5-nano-2025-08-07"
|
||||
GPT5_CHAT = "gpt-5-chat-latest"
|
||||
GPT41 = "gpt-4.1-2025-04-14"
|
||||
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
|
||||
GPT4O_MINI = "gpt-4o-mini"
|
||||
@@ -88,6 +94,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
GPT4_TURBO = "gpt-4-turbo"
|
||||
GPT3_5_TURBO = "gpt-3.5-turbo"
|
||||
# Anthropic models
|
||||
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
|
||||
CLAUDE_4_OPUS = "claude-opus-4-20250514"
|
||||
CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
|
||||
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
|
||||
@@ -115,6 +122,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
OLLAMA_LLAMA3_405B = "llama3.1:405b"
|
||||
OLLAMA_DOLPHIN = "dolphin-mistral:latest"
|
||||
# OpenRouter models
|
||||
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
|
||||
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
|
||||
GEMINI_FLASH_1_5 = "google/gemini-flash-1.5"
|
||||
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
|
||||
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
|
||||
@@ -147,6 +156,10 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
|
||||
LLAMA_API_LLAMA3_3_8B = "Llama-3.3-8B-Instruct"
|
||||
LLAMA_API_LLAMA3_3_70B = "Llama-3.3-70B-Instruct"
|
||||
# v0 by Vercel models
|
||||
V0_1_5_MD = "v0-1.5-md"
|
||||
V0_1_5_LG = "v0-1.5-lg"
|
||||
V0_1_0_MD = "v0-1.0-md"
|
||||
|
||||
@property
|
||||
def metadata(self) -> ModelMetadata:
|
||||
@@ -171,6 +184,11 @@ MODEL_METADATA = {
|
||||
LlmModel.O3_MINI: ModelMetadata("openai", 200000, 100000), # o3-mini-2025-01-31
|
||||
LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
|
||||
LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
|
||||
# GPT-5 models
|
||||
LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_MINI: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_NANO: ModelMetadata("openai", 400000, 128000),
|
||||
LlmModel.GPT5_CHAT: ModelMetadata("openai", 400000, 16384),
|
||||
LlmModel.GPT41: ModelMetadata("openai", 1047576, 32768),
|
||||
LlmModel.GPT41_MINI: ModelMetadata("openai", 1047576, 32768),
|
||||
LlmModel.GPT4O_MINI: ModelMetadata(
|
||||
@@ -182,6 +200,9 @@ MODEL_METADATA = {
|
||||
), # gpt-4-turbo-2024-04-09
|
||||
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, 4096), # gpt-3.5-turbo-0125
|
||||
# https://docs.anthropic.com/en/docs/about-claude/models
|
||||
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 32000
|
||||
), # claude-opus-4-1-20250805
|
||||
LlmModel.CLAUDE_4_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 8192
|
||||
), # claude-4-opus-20250514
|
||||
@@ -246,6 +267,8 @@ MODEL_METADATA = {
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata(
|
||||
"open_router", 12288, 12288
|
||||
),
|
||||
LlmModel.OPENAI_GPT_OSS_120B: ModelMetadata("open_router", 131072, 131072),
|
||||
LlmModel.OPENAI_GPT_OSS_20B: ModelMetadata("open_router", 131072, 32768),
|
||||
LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata("open_router", 300000, 5120),
|
||||
LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata("open_router", 128000, 5120),
|
||||
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata("open_router", 300000, 5120),
|
||||
@@ -262,6 +285,10 @@ MODEL_METADATA = {
|
||||
LlmModel.LLAMA_API_LLAMA4_MAVERICK: ModelMetadata("llama_api", 128000, 4028),
|
||||
LlmModel.LLAMA_API_LLAMA3_3_8B: ModelMetadata("llama_api", 128000, 4028),
|
||||
LlmModel.LLAMA_API_LLAMA3_3_70B: ModelMetadata("llama_api", 128000, 4028),
|
||||
# v0 by Vercel models
|
||||
LlmModel.V0_1_5_MD: ModelMetadata("v0", 128000, 64000),
|
||||
LlmModel.V0_1_5_LG: ModelMetadata("v0", 512000, 64000),
|
||||
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000),
|
||||
}
|
||||
|
||||
for model in LlmModel:
|
||||
@@ -475,6 +502,7 @@ async def llm_call(
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=an_tools,
|
||||
timeout=600,
|
||||
)
|
||||
|
||||
if not resp.content:
|
||||
@@ -657,7 +685,11 @@ async def llm_call(
|
||||
client = openai.OpenAI(
|
||||
base_url="https://api.aimlapi.com/v2",
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
default_headers={"X-Project": "AutoGPT"},
|
||||
default_headers={
|
||||
"X-Project": "AutoGPT",
|
||||
"X-Title": "AutoGPT",
|
||||
"HTTP-Referer": "https://github.com/Significant-Gravitas/AutoGPT",
|
||||
},
|
||||
)
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
@@ -677,6 +709,42 @@ async def llm_call(
|
||||
),
|
||||
reasoning=None,
|
||||
)
|
||||
elif provider == "v0":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
client = openai.AsyncOpenAI(
|
||||
base_url="https://api.v0.dev/v1",
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
response_format = None
|
||||
if json_format:
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
parallel_tool_calls_param = get_parallel_tool_calls_param(
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=response.choices[0].message,
|
||||
prompt=prompt,
|
||||
response=response.choices[0].message.content or "",
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
||||
|
||||
@@ -1,22 +1,13 @@
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_database_manager_client():
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManagerAsyncClient, health_check=False)
|
||||
|
||||
|
||||
StorageScope = Literal["within_agent", "across_agents"]
|
||||
|
||||
|
||||
@@ -88,7 +79,7 @@ class PersistInformationBlock(Block):
|
||||
async def _store_data(
|
||||
self, user_id: str, node_exec_id: str, key: str, data: Any
|
||||
) -> Any | None:
|
||||
return await get_database_manager_client().set_execution_kv_data(
|
||||
return await get_database_manager_async_client().set_execution_kv_data(
|
||||
user_id=user_id,
|
||||
node_exec_id=node_exec_id,
|
||||
key=key,
|
||||
@@ -149,7 +140,7 @@ class RetrieveInformationBlock(Block):
|
||||
yield "value", input_data.default_value
|
||||
|
||||
async def _retrieve_data(self, user_id: str, key: str) -> Any | None:
|
||||
return await get_database_manager_client().get_execution_kv_data(
|
||||
return await get_database_manager_async_client().get_execution_kv_data(
|
||||
user_id=user_id,
|
||||
key=key,
|
||||
)
|
||||
|
||||
@@ -3,8 +3,7 @@ from typing import List
|
||||
|
||||
from backend.data.block import BlockOutput, BlockSchema
|
||||
from backend.data.model import APIKeyCredentials, SchemaField
|
||||
from backend.util import settings
|
||||
from backend.util.settings import BehaveAs
|
||||
from backend.util.settings import BehaveAs, Settings
|
||||
|
||||
from ._api import (
|
||||
TEST_CREDENTIALS,
|
||||
@@ -16,6 +15,8 @@ from ._api import (
|
||||
)
|
||||
from .base import Slant3DBlockBase
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
class Slant3DCreateOrderBlock(Slant3DBlockBase):
|
||||
"""Block for creating new orders"""
|
||||
@@ -280,7 +281,7 @@ class Slant3DGetOrdersBlock(Slant3DBlockBase):
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
# This block is disabled for cloud hosted because it allows access to all orders for the account
|
||||
disabled=settings.Settings().config.behave_as == BehaveAs.CLOUD,
|
||||
disabled=settings.config.behave_as == BehaveAs.CLOUD,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
|
||||
@@ -9,8 +9,7 @@ from backend.data.block import (
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import settings
|
||||
from backend.util.settings import AppEnvironment, BehaveAs
|
||||
from backend.util.settings import AppEnvironment, BehaveAs, Settings
|
||||
|
||||
from ._api import (
|
||||
TEST_CREDENTIALS,
|
||||
@@ -19,6 +18,8 @@ from ._api import (
|
||||
Slant3DCredentialsInput,
|
||||
)
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
class Slant3DTriggerBase:
|
||||
"""Base class for Slant3D webhook triggers"""
|
||||
@@ -76,8 +77,8 @@ class Slant3DOrderWebhookBlock(Slant3DTriggerBase, Block):
|
||||
),
|
||||
# All webhooks are currently subscribed to for all orders. This works for self hosted, but not for cloud hosted prod
|
||||
disabled=(
|
||||
settings.Settings().config.behave_as == BehaveAs.CLOUD
|
||||
and settings.Settings().config.app_env != AppEnvironment.LOCAL
|
||||
settings.config.behave_as == BehaveAs.CLOUD
|
||||
and settings.config.app_env != AppEnvironment.LOCAL
|
||||
),
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=self.Input,
|
||||
|
||||
@@ -3,8 +3,6 @@ import re
|
||||
from collections import Counter
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
import backend.blocks.llm as llm
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data.block import (
|
||||
@@ -17,6 +15,7 @@ from backend.data.block import (
|
||||
)
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util import json
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.graph import Link, Node
|
||||
@@ -24,14 +23,6 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_database_manager_client():
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManagerAsyncClient, health_check=False)
|
||||
|
||||
|
||||
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Return a list of tool_call_ids if the entry is a tool request.
|
||||
@@ -300,9 +291,32 @@ class SmartDecisionMakerBlock(Block):
|
||||
|
||||
for link in links:
|
||||
sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
|
||||
properties[sink_name] = sink_block_input_schema.get_field_schema(
|
||||
link.sink_name
|
||||
)
|
||||
|
||||
# Handle dynamic fields (e.g., values_#_*, items_$_*, etc.)
|
||||
# These are fields that get merged by the executor into their base field
|
||||
if (
|
||||
"_#_" in link.sink_name
|
||||
or "_$_" in link.sink_name
|
||||
or "_@_" in link.sink_name
|
||||
):
|
||||
# For dynamic fields, provide a generic string schema
|
||||
# The executor will handle merging these into the appropriate structure
|
||||
properties[sink_name] = {
|
||||
"type": "string",
|
||||
"description": f"Dynamic value for {link.sink_name}",
|
||||
}
|
||||
else:
|
||||
# For regular fields, use the block's schema
|
||||
try:
|
||||
properties[sink_name] = sink_block_input_schema.get_field_schema(
|
||||
link.sink_name
|
||||
)
|
||||
except (KeyError, AttributeError):
|
||||
# If the field doesn't exist in the schema, provide a generic schema
|
||||
properties[sink_name] = {
|
||||
"type": "string",
|
||||
"description": f"Value for {link.sink_name}",
|
||||
}
|
||||
|
||||
tool_function["parameters"] = {
|
||||
**block.input_schema.jsonschema(),
|
||||
@@ -333,7 +347,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
if not graph_id or not graph_version:
|
||||
raise ValueError("Graph ID or Graph Version not found in sink node.")
|
||||
|
||||
db_client = get_database_manager_client()
|
||||
db_client = get_database_manager_async_client()
|
||||
sink_graph_meta = await db_client.get_graph_metadata(graph_id, graph_version)
|
||||
if not sink_graph_meta:
|
||||
raise ValueError(
|
||||
@@ -393,7 +407,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
ValueError: If no tool links are found for the specified node_id, or if a sink node
|
||||
or its metadata cannot be found.
|
||||
"""
|
||||
db_client = get_database_manager_client()
|
||||
db_client = get_database_manager_async_client()
|
||||
tools = [
|
||||
(link, node)
|
||||
for link, node in await db_client.get_connected_output_nodes(node_id)
|
||||
@@ -487,10 +501,6 @@ class SmartDecisionMakerBlock(Block):
|
||||
}
|
||||
)
|
||||
prompt.extend(tool_output)
|
||||
if input_data.multiple_tool_calls:
|
||||
input_data.sys_prompt += "\nYou can call a tool (different tools) multiple times in a single response."
|
||||
else:
|
||||
input_data.sys_prompt += "\nOnly provide EXACTLY one function call, multiple tool calls is strictly prohibited."
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
@@ -529,15 +539,6 @@ class SmartDecisionMakerBlock(Block):
|
||||
)
|
||||
)
|
||||
|
||||
# Add reasoning to conversation history if available
|
||||
if response.reasoning:
|
||||
prompt.append(
|
||||
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
|
||||
)
|
||||
|
||||
prompt.append(response.raw_response)
|
||||
yield "conversations", prompt
|
||||
|
||||
if not response.tool_calls:
|
||||
yield "finished", response.response
|
||||
return
|
||||
@@ -571,3 +572,12 @@ class SmartDecisionMakerBlock(Block):
|
||||
yield f"tools_^_{tool_name}_~_{arg_name}", tool_args[arg_name]
|
||||
else:
|
||||
yield f"tools_^_{tool_name}_~_{arg_name}", None
|
||||
|
||||
# Add reasoning to conversation history if available
|
||||
if response.reasoning:
|
||||
prompt.append(
|
||||
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
|
||||
)
|
||||
|
||||
prompt.append(response.raw_response)
|
||||
yield "conversations", prompt
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
from prisma.models import User
|
||||
|
||||
from backend.data.model import ProviderName
|
||||
from backend.data.model import ProviderName, User
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
|
||||
@@ -0,0 +1,130 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.data_manipulation import AddToListBlock, CreateDictionaryBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_handles_dynamic_dict_fields():
|
||||
"""Test Smart Decision Maker can handle dynamic dictionary fields (_#_) for any block"""
|
||||
|
||||
# Create a mock node for CreateDictionaryBlock
|
||||
mock_node = Mock()
|
||||
mock_node.block = CreateDictionaryBlock()
|
||||
mock_node.block_id = CreateDictionaryBlock().id
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create mock links with dynamic dictionary fields
|
||||
mock_links = [
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_name",
|
||||
sink_name="values_#_name", # Dynamic dict field
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_age",
|
||||
sink_name="values_#_age", # Dynamic dict field
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_create_dict_~_city",
|
||||
sink_name="values_#_city", # Dynamic dict field
|
||||
sink_id="dict_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
# Generate function signature
|
||||
signature = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
mock_node, mock_links # type: ignore
|
||||
)
|
||||
|
||||
# Verify the signature was created successfully
|
||||
assert signature["type"] == "function"
|
||||
assert "parameters" in signature["function"]
|
||||
assert "properties" in signature["function"]["parameters"]
|
||||
|
||||
# Check that dynamic fields are handled
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
assert len(properties) == 3 # Should have all three fields
|
||||
|
||||
# Each dynamic field should have proper schema
|
||||
for prop_value in properties.values():
|
||||
assert "type" in prop_value
|
||||
assert prop_value["type"] == "string" # Dynamic fields get string type
|
||||
assert "description" in prop_value
|
||||
assert "Dynamic value for" in prop_value["description"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_handles_dynamic_list_fields():
|
||||
"""Test Smart Decision Maker can handle dynamic list fields (_$_) for any block"""
|
||||
|
||||
# Create a mock node for AddToListBlock
|
||||
mock_node = Mock()
|
||||
mock_node.block = AddToListBlock()
|
||||
mock_node.block_id = AddToListBlock().id
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create mock links with dynamic list fields
|
||||
mock_links = [
|
||||
Mock(
|
||||
source_name="tools_^_add_to_list_~_0",
|
||||
sink_name="entries_$_0", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
Mock(
|
||||
source_name="tools_^_add_to_list_~_1",
|
||||
sink_name="entries_$_1", # Dynamic list field
|
||||
sink_id="list_node_id",
|
||||
source_id="smart_decision_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
# Generate function signature
|
||||
signature = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
mock_node, mock_links # type: ignore
|
||||
)
|
||||
|
||||
# Verify dynamic list fields are handled properly
|
||||
assert signature["type"] == "function"
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
assert len(properties) == 2 # Should have both list items
|
||||
|
||||
# Each dynamic field should have proper schema
|
||||
for prop_value in properties.values():
|
||||
assert prop_value["type"] == "string"
|
||||
assert "Dynamic value for" in prop_value["description"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_dict_block_with_dynamic_values():
|
||||
"""Test CreateDictionaryBlock processes dynamic values correctly"""
|
||||
|
||||
block = CreateDictionaryBlock()
|
||||
|
||||
# Simulate what happens when executor merges dynamic fields
|
||||
# The executor merges values_#_* fields into the values dict
|
||||
input_data = block.input_schema(
|
||||
values={
|
||||
"existing": "value",
|
||||
"name": "Alice", # This would come from values_#_name
|
||||
"age": 25, # This would come from values_#_age
|
||||
}
|
||||
)
|
||||
|
||||
# Run the block
|
||||
result = {}
|
||||
async for output_name, output_value in block.run(input_data):
|
||||
result[output_name] = output_value
|
||||
|
||||
# Check the result
|
||||
assert "dictionary" in result
|
||||
assert result["dictionary"]["existing"] == "value"
|
||||
assert result["dictionary"]["name"] == "Alice"
|
||||
assert result["dictionary"]["age"] == 25
|
||||
@@ -1,19 +1,78 @@
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Union
|
||||
from typing import Any, Literal, Union
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
# Shared timezone literal type for all time/date blocks
|
||||
TimezoneLiteral = Literal[
|
||||
"UTC", # UTC±00:00
|
||||
"Pacific/Honolulu", # UTC-10:00
|
||||
"America/Anchorage", # UTC-09:00 (Alaska)
|
||||
"America/Los_Angeles", # UTC-08:00 (Pacific)
|
||||
"America/Denver", # UTC-07:00 (Mountain)
|
||||
"America/Chicago", # UTC-06:00 (Central)
|
||||
"America/New_York", # UTC-05:00 (Eastern)
|
||||
"America/Caracas", # UTC-04:00
|
||||
"America/Sao_Paulo", # UTC-03:00
|
||||
"America/St_Johns", # UTC-02:30 (Newfoundland)
|
||||
"Atlantic/South_Georgia", # UTC-02:00
|
||||
"Atlantic/Azores", # UTC-01:00
|
||||
"Europe/London", # UTC+00:00 (GMT/BST)
|
||||
"Europe/Paris", # UTC+01:00 (CET)
|
||||
"Europe/Athens", # UTC+02:00 (EET)
|
||||
"Europe/Moscow", # UTC+03:00
|
||||
"Asia/Tehran", # UTC+03:30 (Iran)
|
||||
"Asia/Dubai", # UTC+04:00
|
||||
"Asia/Kabul", # UTC+04:30 (Afghanistan)
|
||||
"Asia/Karachi", # UTC+05:00 (Pakistan)
|
||||
"Asia/Kolkata", # UTC+05:30 (India)
|
||||
"Asia/Kathmandu", # UTC+05:45 (Nepal)
|
||||
"Asia/Dhaka", # UTC+06:00 (Bangladesh)
|
||||
"Asia/Yangon", # UTC+06:30 (Myanmar)
|
||||
"Asia/Bangkok", # UTC+07:00
|
||||
"Asia/Shanghai", # UTC+08:00 (China)
|
||||
"Australia/Eucla", # UTC+08:45
|
||||
"Asia/Tokyo", # UTC+09:00 (Japan)
|
||||
"Australia/Adelaide", # UTC+09:30
|
||||
"Australia/Sydney", # UTC+10:00
|
||||
"Australia/Lord_Howe", # UTC+10:30
|
||||
"Pacific/Noumea", # UTC+11:00
|
||||
"Pacific/Auckland", # UTC+12:00 (New Zealand)
|
||||
"Pacific/Chatham", # UTC+12:45
|
||||
"Pacific/Tongatapu", # UTC+13:00
|
||||
"Pacific/Kiritimati", # UTC+14:00
|
||||
"Etc/GMT-12", # UTC+12:00
|
||||
"Etc/GMT+12", # UTC-12:00
|
||||
]
|
||||
|
||||
|
||||
class TimeStrftimeFormat(BaseModel):
|
||||
discriminator: Literal["strftime"]
|
||||
format: str = "%H:%M:%S"
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
|
||||
|
||||
class TimeISO8601Format(BaseModel):
|
||||
discriminator: Literal["iso8601"]
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
include_microseconds: bool = False
|
||||
|
||||
|
||||
class GetCurrentTimeBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
trigger: str = SchemaField(
|
||||
description="Trigger any data to output the current time"
|
||||
)
|
||||
format: str = SchemaField(
|
||||
description="Format of the time to output", default="%H:%M:%S"
|
||||
format_type: Union[TimeStrftimeFormat, TimeISO8601Format] = SchemaField(
|
||||
discriminator="discriminator",
|
||||
description="Format type for time output (strftime with custom format or ISO 8601)",
|
||||
default=TimeStrftimeFormat(discriminator="strftime"),
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -30,19 +89,65 @@ class GetCurrentTimeBlock(Block):
|
||||
output_schema=GetCurrentTimeBlock.Output,
|
||||
test_input=[
|
||||
{"trigger": "Hello"},
|
||||
{"trigger": "Hello", "format": "%H:%M"},
|
||||
{
|
||||
"trigger": "Hello",
|
||||
"format_type": {
|
||||
"discriminator": "strftime",
|
||||
"format": "%H:%M",
|
||||
},
|
||||
},
|
||||
{
|
||||
"trigger": "Hello",
|
||||
"format_type": {
|
||||
"discriminator": "iso8601",
|
||||
"timezone": "UTC",
|
||||
"include_microseconds": False,
|
||||
},
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("time", lambda _: time.strftime("%H:%M:%S")),
|
||||
("time", lambda _: time.strftime("%H:%M")),
|
||||
(
|
||||
"time",
|
||||
lambda t: "T" in t and ("+" in t or "Z" in t),
|
||||
), # Check for ISO format with timezone
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
current_time = time.strftime(input_data.format)
|
||||
if isinstance(input_data.format_type, TimeISO8601Format):
|
||||
# ISO 8601 format for time only (extract time portion from full ISO datetime)
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
|
||||
# Get the full ISO format and extract just the time portion with timezone
|
||||
if input_data.format_type.include_microseconds:
|
||||
full_iso = dt.isoformat()
|
||||
else:
|
||||
full_iso = dt.isoformat(timespec="seconds")
|
||||
|
||||
# Extract time portion (everything after 'T')
|
||||
current_time = full_iso.split("T")[1] if "T" in full_iso else full_iso
|
||||
current_time = f"T{current_time}" # Add T prefix for ISO 8601 time format
|
||||
else: # TimeStrftimeFormat
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
current_time = dt.strftime(input_data.format_type.format)
|
||||
yield "time", current_time
|
||||
|
||||
|
||||
class DateStrftimeFormat(BaseModel):
|
||||
discriminator: Literal["strftime"]
|
||||
format: str = "%Y-%m-%d"
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
|
||||
|
||||
class DateISO8601Format(BaseModel):
|
||||
discriminator: Literal["iso8601"]
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
|
||||
|
||||
class GetCurrentDateBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
trigger: str = SchemaField(
|
||||
@@ -53,8 +158,10 @@ class GetCurrentDateBlock(Block):
|
||||
description="Offset in days from the current date",
|
||||
default=0,
|
||||
)
|
||||
format: str = SchemaField(
|
||||
description="Format of the date to output", default="%Y-%m-%d"
|
||||
format_type: Union[DateStrftimeFormat, DateISO8601Format] = SchemaField(
|
||||
discriminator="discriminator",
|
||||
description="Format type for date output (strftime with custom format or ISO 8601)",
|
||||
default=DateStrftimeFormat(discriminator="strftime"),
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -71,7 +178,22 @@ class GetCurrentDateBlock(Block):
|
||||
output_schema=GetCurrentDateBlock.Output,
|
||||
test_input=[
|
||||
{"trigger": "Hello", "offset": "7"},
|
||||
{"trigger": "Hello", "offset": "7", "format": "%m/%d/%Y"},
|
||||
{
|
||||
"trigger": "Hello",
|
||||
"offset": "7",
|
||||
"format_type": {
|
||||
"discriminator": "strftime",
|
||||
"format": "%m/%d/%Y",
|
||||
},
|
||||
},
|
||||
{
|
||||
"trigger": "Hello",
|
||||
"offset": "0",
|
||||
"format_type": {
|
||||
"discriminator": "iso8601",
|
||||
"timezone": "UTC",
|
||||
},
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
@@ -85,6 +207,12 @@ class GetCurrentDateBlock(Block):
|
||||
< timedelta(days=8),
|
||||
# 7 days difference + 1 day error margin.
|
||||
),
|
||||
(
|
||||
"date",
|
||||
lambda t: len(t) == 10
|
||||
and t[4] == "-"
|
||||
and t[7] == "-", # ISO date format YYYY-MM-DD
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -93,8 +221,31 @@ class GetCurrentDateBlock(Block):
|
||||
offset = int(input_data.offset)
|
||||
except ValueError:
|
||||
offset = 0
|
||||
current_date = datetime.now() - timedelta(days=offset)
|
||||
yield "date", current_date.strftime(input_data.format)
|
||||
|
||||
if isinstance(input_data.format_type, DateISO8601Format):
|
||||
# ISO 8601 format for date only (YYYY-MM-DD)
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
current_date = datetime.now(tz=tz) - timedelta(days=offset)
|
||||
# ISO 8601 date format is YYYY-MM-DD
|
||||
date_str = current_date.date().isoformat()
|
||||
else: # DateStrftimeFormat
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
current_date = datetime.now(tz=tz) - timedelta(days=offset)
|
||||
date_str = current_date.strftime(input_data.format_type.format)
|
||||
|
||||
yield "date", date_str
|
||||
|
||||
|
||||
class StrftimeFormat(BaseModel):
|
||||
discriminator: Literal["strftime"]
|
||||
format: str = "%Y-%m-%d %H:%M:%S"
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
|
||||
|
||||
class ISO8601Format(BaseModel):
|
||||
discriminator: Literal["iso8601"]
|
||||
timezone: TimezoneLiteral = "UTC"
|
||||
include_microseconds: bool = False
|
||||
|
||||
|
||||
class GetCurrentDateAndTimeBlock(Block):
|
||||
@@ -102,9 +253,10 @@ class GetCurrentDateAndTimeBlock(Block):
|
||||
trigger: str = SchemaField(
|
||||
description="Trigger any data to output the current date and time"
|
||||
)
|
||||
format: str = SchemaField(
|
||||
description="Format of the date and time to output",
|
||||
default="%Y-%m-%d %H:%M:%S",
|
||||
format_type: Union[StrftimeFormat, ISO8601Format] = SchemaField(
|
||||
discriminator="discriminator",
|
||||
description="Format type for date and time output (strftime with custom format or ISO 8601/RFC 3339)",
|
||||
default=StrftimeFormat(discriminator="strftime"),
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -121,20 +273,63 @@ class GetCurrentDateAndTimeBlock(Block):
|
||||
output_schema=GetCurrentDateAndTimeBlock.Output,
|
||||
test_input=[
|
||||
{"trigger": "Hello"},
|
||||
{
|
||||
"trigger": "Hello",
|
||||
"format_type": {
|
||||
"discriminator": "strftime",
|
||||
"format": "%Y/%m/%d",
|
||||
},
|
||||
},
|
||||
{
|
||||
"trigger": "Hello",
|
||||
"format_type": {
|
||||
"discriminator": "iso8601",
|
||||
"timezone": "UTC",
|
||||
"include_microseconds": False,
|
||||
},
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
"date_time",
|
||||
lambda t: abs(
|
||||
datetime.now() - datetime.strptime(t, "%Y-%m-%d %H:%M:%S")
|
||||
datetime.now(tz=ZoneInfo("UTC"))
|
||||
- datetime.strptime(t + "+00:00", "%Y-%m-%d %H:%M:%S%z")
|
||||
)
|
||||
< timedelta(seconds=10), # 10 seconds error margin.
|
||||
),
|
||||
(
|
||||
"date_time",
|
||||
lambda t: abs(
|
||||
datetime.now().date() - datetime.strptime(t, "%Y/%m/%d").date()
|
||||
)
|
||||
< timedelta(days=1), # Date format only, no time component
|
||||
),
|
||||
(
|
||||
"date_time",
|
||||
lambda t: abs(
|
||||
datetime.now(tz=ZoneInfo("UTC")) - datetime.fromisoformat(t)
|
||||
)
|
||||
< timedelta(seconds=10), # 10 seconds error margin for ISO format.
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
current_date_time = time.strftime(input_data.format)
|
||||
if isinstance(input_data.format_type, ISO8601Format):
|
||||
# ISO 8601 format with specified timezone (also RFC3339-compliant)
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
|
||||
# Format with or without microseconds
|
||||
if input_data.format_type.include_microseconds:
|
||||
current_date_time = dt.isoformat()
|
||||
else:
|
||||
current_date_time = dt.isoformat(timespec="seconds")
|
||||
else: # StrftimeFormat
|
||||
tz = ZoneInfo(input_data.format_type.timezone)
|
||||
dt = datetime.now(tz=tz)
|
||||
current_date_time = dt.strftime(input_data.format_type.format)
|
||||
yield "date_time", current_date_time
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,12 @@ from backend.blocks.ai_shortform_video_block import AIShortformVideoCreatorBlock
|
||||
from backend.blocks.apollo.organization import SearchOrganizationsBlock
|
||||
from backend.blocks.apollo.people import SearchPeopleBlock
|
||||
from backend.blocks.apollo.person import GetPersonDetailBlock
|
||||
from backend.blocks.enrichlayer.linkedin import (
|
||||
GetLinkedinProfileBlock,
|
||||
GetLinkedinProfilePictureBlock,
|
||||
LinkedinPersonLookupBlock,
|
||||
LinkedinRoleLookupBlock,
|
||||
)
|
||||
from backend.blocks.flux_kontext import AIImageEditorBlock, FluxKontextModelName
|
||||
from backend.blocks.ideogram import IdeogramModelBlock
|
||||
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
|
||||
@@ -30,6 +36,7 @@ from backend.integrations.credentials_store import (
|
||||
anthropic_credentials,
|
||||
apollo_credentials,
|
||||
did_credentials,
|
||||
enrichlayer_credentials,
|
||||
groq_credentials,
|
||||
ideogram_credentials,
|
||||
jina_credentials,
|
||||
@@ -39,6 +46,7 @@ from backend.integrations.credentials_store import (
|
||||
replicate_credentials,
|
||||
revid_credentials,
|
||||
unreal_credentials,
|
||||
v0_credentials,
|
||||
)
|
||||
|
||||
# =============== Configure the cost for each LLM Model call =============== #
|
||||
@@ -48,12 +56,18 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.O3_MINI: 2, # $1.10 / $4.40
|
||||
LlmModel.O1: 16, # $15 / $60
|
||||
LlmModel.O1_MINI: 4,
|
||||
# GPT-5 models
|
||||
LlmModel.GPT5: 2,
|
||||
LlmModel.GPT5_MINI: 1,
|
||||
LlmModel.GPT5_NANO: 1,
|
||||
LlmModel.GPT5_CHAT: 2,
|
||||
LlmModel.GPT41: 2,
|
||||
LlmModel.GPT41_MINI: 1,
|
||||
LlmModel.GPT4O_MINI: 1,
|
||||
LlmModel.GPT4O: 3,
|
||||
LlmModel.GPT4_TURBO: 10,
|
||||
LlmModel.GPT3_5_TURBO: 1,
|
||||
LlmModel.CLAUDE_4_1_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_SONNET: 5,
|
||||
LlmModel.CLAUDE_3_7_SONNET: 5,
|
||||
@@ -76,6 +90,8 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.OLLAMA_LLAMA3_405B: 1,
|
||||
LlmModel.DEEPSEEK_LLAMA_70B: 1, # ? / ?
|
||||
LlmModel.OLLAMA_DOLPHIN: 1,
|
||||
LlmModel.OPENAI_GPT_OSS_120B: 1,
|
||||
LlmModel.OPENAI_GPT_OSS_20B: 1,
|
||||
LlmModel.GEMINI_FLASH_1_5: 1,
|
||||
LlmModel.GEMINI_2_5_PRO: 4,
|
||||
LlmModel.MISTRAL_NEMO: 1,
|
||||
@@ -107,6 +123,10 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
|
||||
LlmModel.GEMINI_2_0_FLASH_LITE: 1,
|
||||
LlmModel.DEEPSEEK_R1_0528: 1,
|
||||
# v0 by Vercel models
|
||||
LlmModel.V0_1_5_MD: 1,
|
||||
LlmModel.V0_1_5_LG: 2,
|
||||
LlmModel.V0_1_0_MD: 1,
|
||||
}
|
||||
|
||||
for model in LlmModel:
|
||||
@@ -196,6 +216,23 @@ LLM_COST = (
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "llama_api"
|
||||
]
|
||||
# v0 by Vercel Models
|
||||
+ [
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_filter={
|
||||
"model": model,
|
||||
"credentials": {
|
||||
"id": v0_credentials.id,
|
||||
"provider": v0_credentials.provider,
|
||||
"type": v0_credentials.type,
|
||||
},
|
||||
},
|
||||
cost_amount=cost,
|
||||
)
|
||||
for model, cost in MODEL_COST.items()
|
||||
if MODEL_METADATA[model].provider == "v0"
|
||||
]
|
||||
# AI/ML Api Models
|
||||
+ [
|
||||
BlockCost(
|
||||
@@ -368,6 +405,54 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
},
|
||||
)
|
||||
],
|
||||
GetLinkedinProfileBlock: [
|
||||
BlockCost(
|
||||
cost_amount=1,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": enrichlayer_credentials.id,
|
||||
"provider": enrichlayer_credentials.provider,
|
||||
"type": enrichlayer_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
LinkedinPersonLookupBlock: [
|
||||
BlockCost(
|
||||
cost_amount=2,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": enrichlayer_credentials.id,
|
||||
"provider": enrichlayer_credentials.provider,
|
||||
"type": enrichlayer_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
LinkedinRoleLookupBlock: [
|
||||
BlockCost(
|
||||
cost_amount=3,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": enrichlayer_credentials.id,
|
||||
"provider": enrichlayer_credentials.provider,
|
||||
"type": enrichlayer_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
GetLinkedinProfilePictureBlock: [
|
||||
BlockCost(
|
||||
cost_amount=3,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": enrichlayer_credentials.id,
|
||||
"provider": enrichlayer_credentials.provider,
|
||||
"type": enrichlayer_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
SmartDecisionMakerBlock: LLM_COST,
|
||||
SearchOrganizationsBlock: [
|
||||
BlockCost(
|
||||
|
||||
@@ -286,11 +286,17 @@ class UserCreditBase(ABC):
|
||||
transaction = await CreditTransaction.prisma().find_first_or_raise(
|
||||
where={"transactionKey": transaction_key, "userId": user_id}
|
||||
)
|
||||
|
||||
if transaction.isActive:
|
||||
return
|
||||
|
||||
async with db.locked_transaction(f"usr_trx_{user_id}"):
|
||||
|
||||
transaction = await CreditTransaction.prisma().find_first_or_raise(
|
||||
where={"transactionKey": transaction_key, "userId": user_id}
|
||||
)
|
||||
if transaction.isActive:
|
||||
return
|
||||
|
||||
user_balance, _ = await self._get_credits(user_id)
|
||||
await CreditTransaction.prisma().update(
|
||||
where={
|
||||
@@ -998,8 +1004,8 @@ def get_block_costs() -> dict[str, list[BlockCost]]:
|
||||
async def get_stripe_customer_id(user_id: str) -> str:
|
||||
user = await get_user_by_id(user_id)
|
||||
|
||||
if user.stripeCustomerId:
|
||||
return user.stripeCustomerId
|
||||
if user.stripe_customer_id:
|
||||
return user.stripe_customer_id
|
||||
|
||||
customer = stripe.Customer.create(
|
||||
name=user.name or "",
|
||||
@@ -1022,10 +1028,10 @@ async def set_auto_top_up(user_id: str, config: AutoTopUpConfig):
|
||||
async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
|
||||
user = await get_user_by_id(user_id)
|
||||
|
||||
if not user.topUpConfig:
|
||||
if not user.top_up_config:
|
||||
return AutoTopUpConfig(threshold=0, amount=0)
|
||||
|
||||
return AutoTopUpConfig.model_validate(user.topUpConfig)
|
||||
return AutoTopUpConfig.model_validate(user.top_up_config)
|
||||
|
||||
|
||||
async def admin_get_user_history(
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
import zlib
|
||||
from contextlib import asynccontextmanager
|
||||
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
|
||||
from uuid import uuid4
|
||||
@@ -50,6 +49,10 @@ prisma = Prisma(
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_connected():
|
||||
return prisma.is_connected()
|
||||
|
||||
|
||||
@conn_retry("Prisma", "Acquiring connection")
|
||||
async def connect():
|
||||
if prisma.is_connected():
|
||||
@@ -84,35 +87,50 @@ TRANSACTION_TIMEOUT = 15000 # 15 seconds - Increased from 5s to prevent timeout
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(timeout: int | None = None):
|
||||
async def transaction(timeout: int = TRANSACTION_TIMEOUT):
|
||||
"""
|
||||
Create a database transaction with optional timeout.
|
||||
|
||||
Args:
|
||||
timeout: Transaction timeout in milliseconds. If None, uses TRANSACTION_TIMEOUT (15s).
|
||||
"""
|
||||
if timeout is None:
|
||||
timeout = TRANSACTION_TIMEOUT
|
||||
|
||||
async with prisma.tx(timeout=timeout) as tx:
|
||||
yield tx
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def locked_transaction(key: str, timeout: int | None = None):
|
||||
async def locked_transaction(key: str, timeout: int = TRANSACTION_TIMEOUT):
|
||||
"""
|
||||
Create a database transaction with advisory lock.
|
||||
Create a transaction and take a per-key advisory *transaction* lock.
|
||||
|
||||
- Uses a 64-bit lock id via hashtextextended(key, 0) to avoid 32-bit collisions.
|
||||
- Bound by lock_timeout and statement_timeout so it won't block indefinitely.
|
||||
- Lock is held for the duration of the transaction and auto-released on commit/rollback.
|
||||
|
||||
Args:
|
||||
key: Lock key for advisory lock
|
||||
timeout: Transaction timeout in milliseconds. If None, uses TRANSACTION_TIMEOUT (15s).
|
||||
key: String lock key (e.g., "usr_trx_<uuid>").
|
||||
timeout: Transaction/lock/statement timeout in milliseconds.
|
||||
"""
|
||||
if timeout is None:
|
||||
timeout = TRANSACTION_TIMEOUT
|
||||
|
||||
lock_key = zlib.crc32(key.encode("utf-8"))
|
||||
async with transaction(timeout=timeout) as tx:
|
||||
await tx.execute_raw("SELECT pg_advisory_xact_lock($1)", lock_key)
|
||||
# Ensure we don't wait longer than desired
|
||||
# Note: SET LOCAL doesn't support parameterized queries, must use string interpolation
|
||||
await tx.execute_raw(f"SET LOCAL statement_timeout = '{int(timeout)}ms'") # type: ignore[arg-type]
|
||||
await tx.execute_raw(f"SET LOCAL lock_timeout = '{int(timeout)}ms'") # type: ignore[arg-type]
|
||||
|
||||
# Block until acquired or lock_timeout hits
|
||||
try:
|
||||
await tx.execute_raw(
|
||||
"SELECT pg_advisory_xact_lock(hashtextextended($1, 0))",
|
||||
key,
|
||||
)
|
||||
except Exception as e:
|
||||
# Normalize PG's lock timeout error to TimeoutError for callers
|
||||
if "lock timeout" in str(e).lower():
|
||||
raise TimeoutError(
|
||||
f"Could not acquire lock for key={key!r} within {timeout}ms"
|
||||
) from e
|
||||
raise
|
||||
|
||||
yield tx
|
||||
|
||||
|
||||
|
||||
@@ -33,12 +33,13 @@ from prisma.types import (
|
||||
AgentNodeExecutionUpdateInput,
|
||||
AgentNodeExecutionWhereInput,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, JsonValue
|
||||
from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
|
||||
from pydantic.fields import Field
|
||||
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
from backend.util import type as type_utils
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.settings import Config
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
@@ -134,6 +135,10 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
default=None,
|
||||
description="Error message if any",
|
||||
)
|
||||
activity_status: str | None = Field(
|
||||
default=None,
|
||||
description="AI-generated summary of what the agent did",
|
||||
)
|
||||
|
||||
def to_db(self) -> GraphExecutionStats:
|
||||
return GraphExecutionStats(
|
||||
@@ -145,6 +150,7 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
node_count=self.node_exec_count,
|
||||
node_error_count=self.node_error_count,
|
||||
error=self.error,
|
||||
activity_status=self.activity_status,
|
||||
)
|
||||
|
||||
stats: Stats | None
|
||||
@@ -189,6 +195,7 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
if isinstance(stats.error, Exception)
|
||||
else stats.error
|
||||
),
|
||||
activity_status=stats.activity_status,
|
||||
)
|
||||
if stats
|
||||
else None
|
||||
@@ -311,18 +318,30 @@ class NodeExecutionResult(BaseModel):
|
||||
|
||||
@staticmethod
|
||||
def from_db(_node_exec: AgentNodeExecution, user_id: Optional[str] = None):
|
||||
if _node_exec.executionData:
|
||||
# Execution that has been queued for execution will persist its data.
|
||||
try:
|
||||
stats = NodeExecutionStats.model_validate(_node_exec.stats or {})
|
||||
except (ValueError, ValidationError):
|
||||
stats = NodeExecutionStats()
|
||||
|
||||
if stats.cleared_inputs:
|
||||
input_data: BlockInput = defaultdict()
|
||||
for name, messages in stats.cleared_inputs.items():
|
||||
input_data[name] = messages[-1] if messages else ""
|
||||
elif _node_exec.executionData:
|
||||
input_data = type_utils.convert(_node_exec.executionData, dict[str, Any])
|
||||
else:
|
||||
# For incomplete execution, executionData will not be yet available.
|
||||
input_data: BlockInput = defaultdict()
|
||||
for data in _node_exec.Input or []:
|
||||
input_data[data.name] = type_utils.convert(data.data, type[Any])
|
||||
|
||||
output_data: CompletedBlockOutput = defaultdict(list)
|
||||
for data in _node_exec.Output or []:
|
||||
output_data[data.name].append(type_utils.convert(data.data, type[Any]))
|
||||
|
||||
if stats.cleared_outputs:
|
||||
for name, messages in stats.cleared_outputs.items():
|
||||
output_data[name].extend(messages)
|
||||
else:
|
||||
for data in _node_exec.Output or []:
|
||||
output_data[data.name].append(type_utils.convert(data.data, type[Any]))
|
||||
|
||||
graph_execution: AgentGraphExecution | None = _node_exec.GraphExecution
|
||||
if graph_execution:
|
||||
@@ -630,6 +649,8 @@ async def update_graph_execution_stats(
|
||||
"OR": [
|
||||
{"executionStatus": ExecutionStatus.RUNNING},
|
||||
{"executionStatus": ExecutionStatus.QUEUED},
|
||||
# Terminated graph can be resumed.
|
||||
{"executionStatus": ExecutionStatus.TERMINATED},
|
||||
],
|
||||
},
|
||||
data=update_data,
|
||||
@@ -646,27 +667,6 @@ async def update_graph_execution_stats(
|
||||
return GraphExecution.from_db(graph_exec)
|
||||
|
||||
|
||||
async def update_node_execution_stats(
|
||||
node_exec_id: str, stats: NodeExecutionStats
|
||||
) -> NodeExecutionResult:
|
||||
data = stats.model_dump()
|
||||
if isinstance(data["error"], Exception):
|
||||
data["error"] = str(data["error"])
|
||||
|
||||
res = await AgentNodeExecution.prisma().update(
|
||||
where={"id": node_exec_id},
|
||||
data={
|
||||
"stats": SafeJson(data),
|
||||
"endedTime": datetime.now(tz=timezone.utc),
|
||||
},
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
)
|
||||
if not res:
|
||||
raise ValueError(f"Node execution {node_exec_id} not found.")
|
||||
|
||||
return NodeExecutionResult.from_db(res)
|
||||
|
||||
|
||||
async def update_node_execution_status_batch(
|
||||
node_exec_ids: list[str],
|
||||
status: ExecutionStatus,
|
||||
@@ -896,15 +896,15 @@ class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):
|
||||
|
||||
def publish(self, res: GraphExecution | NodeExecutionResult):
|
||||
if isinstance(res, GraphExecution):
|
||||
self.publish_graph_exec_update(res)
|
||||
self._publish_graph_exec_update(res)
|
||||
else:
|
||||
self.publish_node_exec_update(res)
|
||||
self._publish_node_exec_update(res)
|
||||
|
||||
def publish_node_exec_update(self, res: NodeExecutionResult):
|
||||
def _publish_node_exec_update(self, res: NodeExecutionResult):
|
||||
event = NodeExecutionEvent.model_validate(res.model_dump())
|
||||
self._publish(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
|
||||
|
||||
def publish_graph_exec_update(self, res: GraphExecution):
|
||||
def _publish_graph_exec_update(self, res: GraphExecution):
|
||||
event = GraphExecutionEvent.model_validate(res.model_dump())
|
||||
self._publish(event, f"{res.user_id}/{res.graph_id}/{res.id}")
|
||||
|
||||
@@ -936,17 +936,18 @@ class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
|
||||
def event_bus_name(self) -> str:
|
||||
return config.execution_event_bus_name
|
||||
|
||||
@func_retry
|
||||
async def publish(self, res: GraphExecutionMeta | NodeExecutionResult):
|
||||
if isinstance(res, GraphExecutionMeta):
|
||||
await self.publish_graph_exec_update(res)
|
||||
await self._publish_graph_exec_update(res)
|
||||
else:
|
||||
await self.publish_node_exec_update(res)
|
||||
await self._publish_node_exec_update(res)
|
||||
|
||||
async def publish_node_exec_update(self, res: NodeExecutionResult):
|
||||
async def _publish_node_exec_update(self, res: NodeExecutionResult):
|
||||
event = NodeExecutionEvent.model_validate(res.model_dump())
|
||||
await self._publish(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
|
||||
|
||||
async def publish_graph_exec_update(self, res: GraphExecutionMeta):
|
||||
async def _publish_graph_exec_update(self, res: GraphExecutionMeta):
|
||||
# GraphExecutionEvent requires inputs and outputs fields that GraphExecutionMeta doesn't have
|
||||
# Add default empty values for compatibility
|
||||
event_data = res.model_dump()
|
||||
|
||||
109
autogpt_platform/backend/backend/data/generate_data.py
Normal file
109
autogpt_platform/backend/backend/data/generate_data.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
|
||||
from backend.data.execution import get_graph_executions
|
||||
from backend.data.graph import get_graph_metadata
|
||||
from backend.data.model import UserExecutionSummaryStats
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
from backend.util.logging import TruncatedLogger
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[SummaryData]")
|
||||
|
||||
|
||||
async def get_user_execution_summary_data(
|
||||
user_id: str, start_time: datetime, end_time: datetime
|
||||
) -> UserExecutionSummaryStats:
|
||||
"""Gather all summary data for a user in a time range.
|
||||
|
||||
This function fetches graph executions once and aggregates all required
|
||||
statistics in a single pass for efficiency.
|
||||
"""
|
||||
try:
|
||||
# Fetch graph executions once
|
||||
executions = await get_graph_executions(
|
||||
user_id=user_id,
|
||||
created_time_gte=start_time,
|
||||
created_time_lte=end_time,
|
||||
)
|
||||
|
||||
# Initialize aggregation variables
|
||||
total_credits_used = 0.0
|
||||
total_executions = len(executions)
|
||||
successful_runs = 0
|
||||
failed_runs = 0
|
||||
terminated_runs = 0
|
||||
execution_times = []
|
||||
agent_usage = defaultdict(int)
|
||||
cost_by_graph_id = defaultdict(float)
|
||||
|
||||
# Single pass through executions to aggregate all stats
|
||||
for execution in executions:
|
||||
# Count execution statuses (including TERMINATED as failed)
|
||||
if execution.status == AgentExecutionStatus.COMPLETED:
|
||||
successful_runs += 1
|
||||
elif execution.status == AgentExecutionStatus.FAILED:
|
||||
failed_runs += 1
|
||||
elif execution.status == AgentExecutionStatus.TERMINATED:
|
||||
terminated_runs += 1
|
||||
|
||||
# Aggregate costs from stats
|
||||
if execution.stats and hasattr(execution.stats, "cost"):
|
||||
cost_in_dollars = execution.stats.cost / 100
|
||||
total_credits_used += cost_in_dollars
|
||||
cost_by_graph_id[execution.graph_id] += cost_in_dollars
|
||||
|
||||
# Collect execution times
|
||||
if execution.stats and hasattr(execution.stats, "duration"):
|
||||
execution_times.append(execution.stats.duration)
|
||||
|
||||
# Count agent usage
|
||||
agent_usage[execution.graph_id] += 1
|
||||
|
||||
# Calculate derived stats
|
||||
total_execution_time = sum(execution_times)
|
||||
average_execution_time = (
|
||||
total_execution_time / len(execution_times) if execution_times else 0
|
||||
)
|
||||
|
||||
# Find most used agent
|
||||
most_used_agent = "No agents used"
|
||||
if agent_usage:
|
||||
most_used_agent_id = max(agent_usage, key=lambda k: agent_usage[k])
|
||||
try:
|
||||
graph_meta = await get_graph_metadata(graph_id=most_used_agent_id)
|
||||
most_used_agent = (
|
||||
graph_meta.name if graph_meta else f"Agent {most_used_agent_id[:8]}"
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(f"Could not get metadata for graph {most_used_agent_id}")
|
||||
most_used_agent = f"Agent {most_used_agent_id[:8]}"
|
||||
|
||||
# Convert graph_ids to agent names for cost breakdown
|
||||
cost_breakdown = {}
|
||||
for graph_id, cost in cost_by_graph_id.items():
|
||||
try:
|
||||
graph_meta = await get_graph_metadata(graph_id=graph_id)
|
||||
agent_name = graph_meta.name if graph_meta else f"Agent {graph_id[:8]}"
|
||||
except Exception:
|
||||
logger.warning(f"Could not get metadata for graph {graph_id}")
|
||||
agent_name = f"Agent {graph_id[:8]}"
|
||||
cost_breakdown[agent_name] = cost
|
||||
|
||||
# Build the summary stats object (include terminated runs as failed)
|
||||
return UserExecutionSummaryStats(
|
||||
total_credits_used=total_credits_used,
|
||||
total_executions=total_executions,
|
||||
successful_runs=successful_runs,
|
||||
failed_runs=failed_runs + terminated_runs,
|
||||
most_used_agent=most_used_agent,
|
||||
total_execution_time=total_execution_time,
|
||||
average_execution_time=average_execution_time,
|
||||
cost_breakdown=cost_breakdown,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get user summary data: {e}")
|
||||
raise DatabaseError(f"Failed to get user summary data: {e}") from e
|
||||
@@ -416,6 +416,10 @@ class GraphModel(Graph):
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
):
|
||||
"""
|
||||
Validate graph structure and raise `ValueError` on issues.
|
||||
For structured error reporting, use `validate_graph_get_errors`.
|
||||
"""
|
||||
self._validate_graph(self, for_run, nodes_input_masks)
|
||||
for sub_graph in self.sub_graphs:
|
||||
self._validate_graph(sub_graph, for_run, nodes_input_masks)
|
||||
@@ -425,15 +429,58 @@ class GraphModel(Graph):
|
||||
graph: BaseGraph,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
):
|
||||
def is_tool_pin(name: str) -> bool:
|
||||
return name.startswith("tools_^_")
|
||||
) -> None:
|
||||
errors = GraphModel._validate_graph_get_errors(
|
||||
graph, for_run, nodes_input_masks
|
||||
)
|
||||
if errors:
|
||||
# Just raise the first error for backward compatibility
|
||||
first_error = next(iter(errors.values()))
|
||||
first_field_error = next(iter(first_error.values()))
|
||||
raise ValueError(first_field_error)
|
||||
|
||||
def sanitize(name):
|
||||
sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
|
||||
if is_tool_pin(sanitized_name):
|
||||
return "tools"
|
||||
return sanitized_name
|
||||
def validate_graph_get_errors(
|
||||
self,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Validate graph and return structured errors per node.
|
||||
|
||||
Returns: dict[node_id, dict[field_name, error_message]]
|
||||
"""
|
||||
return {
|
||||
**self._validate_graph_get_errors(self, for_run, nodes_input_masks),
|
||||
**{
|
||||
node_id: error
|
||||
for sub_graph in self.sub_graphs
|
||||
for node_id, error in self._validate_graph_get_errors(
|
||||
sub_graph, for_run, nodes_input_masks
|
||||
).items()
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _validate_graph_get_errors(
|
||||
graph: BaseGraph,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Validate graph and return structured errors per node.
|
||||
|
||||
Returns: dict[node_id, dict[field_name, error_message]]
|
||||
"""
|
||||
# First, check for structural issues with the graph
|
||||
try:
|
||||
GraphModel._validate_graph_structure(graph)
|
||||
except ValueError:
|
||||
# If structural validation fails, we can't provide per-node errors
|
||||
# so we re-raise as is
|
||||
raise
|
||||
|
||||
# Collect errors per node
|
||||
node_errors: dict[str, dict[str, str]] = defaultdict(dict)
|
||||
|
||||
# Validate smart decision maker nodes
|
||||
nodes_block = {
|
||||
@@ -442,7 +489,7 @@ class GraphModel(Graph):
|
||||
if (block := get_block(node.block_id)) is not None
|
||||
}
|
||||
|
||||
input_links = defaultdict(list)
|
||||
input_links: dict[str, list[Link]] = defaultdict(list)
|
||||
|
||||
for link in graph.links:
|
||||
input_links[link.sink_id].append(link)
|
||||
@@ -450,17 +497,22 @@ class GraphModel(Graph):
|
||||
# Nodes: required fields are filled or connected and dependencies are satisfied
|
||||
for node in graph.nodes:
|
||||
if (block := nodes_block.get(node.id)) is None:
|
||||
# For invalid blocks, we still raise immediately as this is a structural issue
|
||||
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
|
||||
|
||||
node_input_mask = (
|
||||
nodes_input_masks.get(node.id, {}) if nodes_input_masks else {}
|
||||
)
|
||||
provided_inputs = set(
|
||||
[sanitize(name) for name in node.input_default]
|
||||
+ [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
|
||||
[_sanitize_pin_name(name) for name in node.input_default]
|
||||
+ [
|
||||
_sanitize_pin_name(link.sink_name)
|
||||
for link in input_links.get(node.id, [])
|
||||
]
|
||||
+ ([name for name in node_input_mask] if node_input_mask else [])
|
||||
)
|
||||
InputSchema = block.input_schema
|
||||
|
||||
for name in (required_fields := InputSchema.get_required_fields()):
|
||||
if (
|
||||
name not in provided_inputs
|
||||
@@ -477,18 +529,16 @@ class GraphModel(Graph):
|
||||
]
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Node {block.name} #{node.id} required input missing: `{name}`"
|
||||
)
|
||||
node_errors[node.id][name] = "This field is required"
|
||||
|
||||
if (
|
||||
block.block_type == BlockType.INPUT
|
||||
and (input_key := node.input_default.get("name"))
|
||||
and is_credentials_field_name(input_key)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Agent input node uses reserved name '{input_key}'; "
|
||||
"'credentials' and `*_credentials` are reserved input names"
|
||||
node_errors[node.id]["name"] = (
|
||||
f"'{input_key}' is a reserved input name: "
|
||||
"'credentials' and `*_credentials` are reserved"
|
||||
)
|
||||
|
||||
# Get input schema properties and check dependencies
|
||||
@@ -538,10 +588,15 @@ class GraphModel(Graph):
|
||||
# Check for missing dependencies when dependent field is present
|
||||
missing_deps = [dep for dep in dependencies if not has_value(node, dep)]
|
||||
if missing_deps and (field_has_value or field_is_required):
|
||||
raise ValueError(
|
||||
f"Node {block.name} #{node.id}: Field `{field_name}` requires [{', '.join(missing_deps)}] to be set"
|
||||
)
|
||||
node_errors[node.id][
|
||||
field_name
|
||||
] = f"Requires {', '.join(missing_deps)} to be set"
|
||||
|
||||
return node_errors
|
||||
|
||||
@staticmethod
|
||||
def _validate_graph_structure(graph: BaseGraph):
|
||||
"""Validate graph structure (links, connections, etc.)"""
|
||||
node_map = {v.id: v for v in graph.nodes}
|
||||
|
||||
def is_static_output_block(nid: str) -> bool:
|
||||
@@ -567,7 +622,7 @@ class GraphModel(Graph):
|
||||
f"{prefix}, {node.block_id} is invalid block id, available blocks: {blocks}"
|
||||
)
|
||||
|
||||
sanitized_name = sanitize(name)
|
||||
sanitized_name = _sanitize_pin_name(name)
|
||||
vals = node.input_default
|
||||
if i == 0:
|
||||
fields = (
|
||||
@@ -581,7 +636,7 @@ class GraphModel(Graph):
|
||||
if block.block_type not in [BlockType.AGENT]
|
||||
else vals.get("input_schema", {}).get("properties", {}).keys()
|
||||
)
|
||||
if sanitized_name not in fields and not is_tool_pin(name):
|
||||
if sanitized_name not in fields and not _is_tool_pin(name):
|
||||
fields_msg = f"Allowed fields: {fields}"
|
||||
raise ValueError(f"{prefix}, `{name}` invalid, {fields_msg}")
|
||||
|
||||
@@ -618,6 +673,17 @@ class GraphModel(Graph):
|
||||
)
|
||||
|
||||
|
||||
def _is_tool_pin(name: str) -> bool:
|
||||
return name.startswith("tools_^_")
|
||||
|
||||
|
||||
def _sanitize_pin_name(name: str) -> str:
|
||||
sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
|
||||
if _is_tool_pin(sanitized_name):
|
||||
return "tools"
|
||||
return sanitized_name
|
||||
|
||||
|
||||
class GraphMeta(Graph):
|
||||
user_id: str
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import enum
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from json import JSONDecodeError
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
@@ -40,12 +41,120 @@ from pydantic_core import (
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.json import loads as json_loads
|
||||
from backend.util.settings import Secrets
|
||||
|
||||
# Type alias for any provider name (including custom ones)
|
||||
AnyProviderName = str # Will be validated as ProviderName at runtime
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
"""Application-layer User model with snake_case convention."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
str_strip_whitespace=True,
|
||||
)
|
||||
|
||||
id: str = Field(..., description="User ID")
|
||||
email: str = Field(..., description="User email address")
|
||||
email_verified: bool = Field(default=True, description="Whether email is verified")
|
||||
name: Optional[str] = Field(None, description="User display name")
|
||||
created_at: datetime = Field(..., description="When user was created")
|
||||
updated_at: datetime = Field(..., description="When user was last updated")
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict, description="User metadata as dict"
|
||||
)
|
||||
integrations: str = Field(default="", description="Encrypted integrations data")
|
||||
stripe_customer_id: Optional[str] = Field(None, description="Stripe customer ID")
|
||||
top_up_config: Optional["AutoTopUpConfig"] = Field(
|
||||
None, description="Top up configuration"
|
||||
)
|
||||
|
||||
# Notification preferences
|
||||
max_emails_per_day: int = Field(default=3, description="Maximum emails per day")
|
||||
notify_on_agent_run: bool = Field(default=True, description="Notify on agent run")
|
||||
notify_on_zero_balance: bool = Field(
|
||||
default=True, description="Notify on zero balance"
|
||||
)
|
||||
notify_on_low_balance: bool = Field(
|
||||
default=True, description="Notify on low balance"
|
||||
)
|
||||
notify_on_block_execution_failed: bool = Field(
|
||||
default=True, description="Notify on block execution failure"
|
||||
)
|
||||
notify_on_continuous_agent_error: bool = Field(
|
||||
default=True, description="Notify on continuous agent error"
|
||||
)
|
||||
notify_on_daily_summary: bool = Field(
|
||||
default=True, description="Notify on daily summary"
|
||||
)
|
||||
notify_on_weekly_summary: bool = Field(
|
||||
default=True, description="Notify on weekly summary"
|
||||
)
|
||||
notify_on_monthly_summary: bool = Field(
|
||||
default=True, description="Notify on monthly summary"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, prisma_user: "PrismaUser") -> "User":
|
||||
"""Convert a database User object to application User model."""
|
||||
# Handle metadata field - convert from JSON string or dict to dict
|
||||
metadata = {}
|
||||
if prisma_user.metadata:
|
||||
if isinstance(prisma_user.metadata, str):
|
||||
try:
|
||||
metadata = json_loads(prisma_user.metadata)
|
||||
except (JSONDecodeError, TypeError):
|
||||
metadata = {}
|
||||
elif isinstance(prisma_user.metadata, dict):
|
||||
metadata = prisma_user.metadata
|
||||
|
||||
# Handle topUpConfig field
|
||||
top_up_config = None
|
||||
if prisma_user.topUpConfig:
|
||||
if isinstance(prisma_user.topUpConfig, str):
|
||||
try:
|
||||
config_dict = json_loads(prisma_user.topUpConfig)
|
||||
top_up_config = AutoTopUpConfig.model_validate(config_dict)
|
||||
except (JSONDecodeError, TypeError, ValueError):
|
||||
top_up_config = None
|
||||
elif isinstance(prisma_user.topUpConfig, dict):
|
||||
try:
|
||||
top_up_config = AutoTopUpConfig.model_validate(
|
||||
prisma_user.topUpConfig
|
||||
)
|
||||
except ValueError:
|
||||
top_up_config = None
|
||||
|
||||
return cls(
|
||||
id=prisma_user.id,
|
||||
email=prisma_user.email,
|
||||
email_verified=prisma_user.emailVerified or True,
|
||||
name=prisma_user.name,
|
||||
created_at=prisma_user.createdAt,
|
||||
updated_at=prisma_user.updatedAt,
|
||||
metadata=metadata,
|
||||
integrations=prisma_user.integrations or "",
|
||||
stripe_customer_id=prisma_user.stripeCustomerId,
|
||||
top_up_config=top_up_config,
|
||||
max_emails_per_day=prisma_user.maxEmailsPerDay or 3,
|
||||
notify_on_agent_run=prisma_user.notifyOnAgentRun or True,
|
||||
notify_on_zero_balance=prisma_user.notifyOnZeroBalance or True,
|
||||
notify_on_low_balance=prisma_user.notifyOnLowBalance or True,
|
||||
notify_on_block_execution_failed=prisma_user.notifyOnBlockExecutionFailed
|
||||
or True,
|
||||
notify_on_continuous_agent_error=prisma_user.notifyOnContinuousAgentError
|
||||
or True,
|
||||
notify_on_daily_summary=prisma_user.notifyOnDailySummary or True,
|
||||
notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True,
|
||||
notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from prisma.models import User as PrismaUser
|
||||
|
||||
from backend.data.block import BlockSchema
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -644,7 +753,7 @@ class NodeExecutionStats(BaseModel):
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
error: Optional[Exception | str] = None
|
||||
error: Optional[BaseException | str] = None
|
||||
walltime: float = 0
|
||||
cputime: float = 0
|
||||
input_size: int = 0
|
||||
@@ -655,6 +764,9 @@ class NodeExecutionStats(BaseModel):
|
||||
output_token_count: int = 0
|
||||
extra_cost: int = 0
|
||||
extra_steps: int = 0
|
||||
# Moderation fields
|
||||
cleared_inputs: Optional[dict[str, list[str]]] = None
|
||||
cleared_outputs: Optional[dict[str, list[str]]] = None
|
||||
|
||||
def __iadd__(self, other: "NodeExecutionStats") -> "NodeExecutionStats":
|
||||
"""Mutate this instance by adding another NodeExecutionStats."""
|
||||
@@ -706,3 +818,24 @@ class GraphExecutionStats(BaseModel):
|
||||
default=0, description="Total number of errors generated"
|
||||
)
|
||||
cost: int = Field(default=0, description="Total execution cost (cents)")
|
||||
activity_status: Optional[str] = Field(
|
||||
default=None, description="AI-generated summary of what the agent did"
|
||||
)
|
||||
|
||||
|
||||
class UserExecutionSummaryStats(BaseModel):
|
||||
"""Summary of user statistics for a specific user."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="allow",
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
total_credits_used: float = Field(default=0)
|
||||
total_executions: int = Field(default=0)
|
||||
successful_runs: int = Field(default=0)
|
||||
failed_runs: int = Field(default=0)
|
||||
most_used_agent: str = Field(default="")
|
||||
total_execution_time: float = Field(default=0)
|
||||
average_execution_time: float = Field(default=0)
|
||||
cost_breakdown: dict[str, float] = Field(default_factory=dict)
|
||||
|
||||
@@ -4,20 +4,12 @@ from enum import Enum
|
||||
from typing import Awaitable, Optional
|
||||
|
||||
import aio_pika
|
||||
import aio_pika.exceptions as aio_ex
|
||||
import pika
|
||||
import pika.adapters.blocking_connection
|
||||
from pika.exceptions import AMQPError
|
||||
from pika.spec import BasicProperties
|
||||
from pydantic import BaseModel
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from backend.util.retry import conn_retry
|
||||
from backend.util.retry import conn_retry, func_retry
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -148,6 +140,7 @@ class SyncRabbitMQ(RabbitMQBase):
|
||||
socket_timeout=SOCKET_TIMEOUT,
|
||||
connection_attempts=CONNECTION_ATTEMPTS,
|
||||
retry_delay=RETRY_DELAY,
|
||||
heartbeat=300, # 5 minute timeout (heartbeats sent every 2.5 min)
|
||||
)
|
||||
|
||||
self._connection = pika.BlockingConnection(parameters)
|
||||
@@ -198,12 +191,7 @@ class SyncRabbitMQ(RabbitMQBase):
|
||||
routing_key=queue.routing_key or queue.name,
|
||||
)
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type((AMQPError, ConnectionError)),
|
||||
wait=wait_random_exponential(multiplier=1, max=5),
|
||||
stop=stop_after_attempt(5),
|
||||
reraise=True,
|
||||
)
|
||||
@func_retry
|
||||
def publish_message(
|
||||
self,
|
||||
routing_key: str,
|
||||
@@ -257,6 +245,7 @@ class AsyncRabbitMQ(RabbitMQBase):
|
||||
password=self.password,
|
||||
virtualhost=self.config.vhost.lstrip("/"),
|
||||
blocked_connection_timeout=BLOCKED_CONNECTION_TIMEOUT,
|
||||
heartbeat=300, # 5 minute timeout (heartbeats sent every 2.5 min)
|
||||
)
|
||||
self._channel = await self._connection.channel()
|
||||
await self._channel.set_qos(prefetch_count=1)
|
||||
@@ -302,12 +291,7 @@ class AsyncRabbitMQ(RabbitMQBase):
|
||||
exchange, routing_key=queue.routing_key or queue.name
|
||||
)
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type((aio_ex.AMQPError, ConnectionError)),
|
||||
wait=wait_random_exponential(multiplier=1, max=5),
|
||||
stop=stop_after_attempt(5),
|
||||
reraise=True,
|
||||
)
|
||||
@func_retry
|
||||
async def publish_message(
|
||||
self,
|
||||
routing_key: str,
|
||||
|
||||
@@ -9,11 +9,11 @@ from urllib.parse import quote_plus
|
||||
from autogpt_libs.auth.models import DEFAULT_USER_ID
|
||||
from fastapi import HTTPException
|
||||
from prisma.enums import NotificationType
|
||||
from prisma.models import User
|
||||
from prisma.models import User as PrismaUser
|
||||
from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
|
||||
|
||||
from backend.data.db import prisma
|
||||
from backend.data.model import UserIntegrations, UserMetadata, UserMetadataRaw
|
||||
from backend.data.model import User, UserIntegrations, UserMetadata
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
from backend.util.encryption import JSONCryptor
|
||||
@@ -21,6 +21,7 @@ from backend.util.json import SafeJson
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
|
||||
async def get_or_create_user(user_data: dict) -> User:
|
||||
@@ -43,7 +44,7 @@ async def get_or_create_user(user_data: dict) -> User:
|
||||
)
|
||||
)
|
||||
|
||||
return User.model_validate(user)
|
||||
return User.from_db(user)
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"Failed to get or create user {user_data}: {e}") from e
|
||||
|
||||
@@ -52,7 +53,7 @@ async def get_user_by_id(user_id: str) -> User:
|
||||
user = await prisma.user.find_unique(where={"id": user_id})
|
||||
if not user:
|
||||
raise ValueError(f"User not found with ID: {user_id}")
|
||||
return User.model_validate(user)
|
||||
return User.from_db(user)
|
||||
|
||||
|
||||
async def get_user_email_by_id(user_id: str) -> Optional[str]:
|
||||
@@ -66,7 +67,7 @@ async def get_user_email_by_id(user_id: str) -> Optional[str]:
|
||||
async def get_user_by_email(email: str) -> Optional[User]:
|
||||
try:
|
||||
user = await prisma.user.find_unique(where={"email": email})
|
||||
return User.model_validate(user) if user else None
|
||||
return User.from_db(user) if user else None
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"Failed to get user by email {email}: {e}") from e
|
||||
|
||||
@@ -90,27 +91,11 @@ async def create_default_user() -> Optional[User]:
|
||||
name="Default User",
|
||||
)
|
||||
)
|
||||
return User.model_validate(user)
|
||||
|
||||
|
||||
async def get_user_metadata(user_id: str) -> UserMetadata:
|
||||
user = await User.prisma().find_unique_or_raise(
|
||||
where={"id": user_id},
|
||||
)
|
||||
|
||||
metadata = cast(UserMetadataRaw, user.metadata)
|
||||
return UserMetadata.model_validate(metadata)
|
||||
|
||||
|
||||
async def update_user_metadata(user_id: str, metadata: UserMetadata):
|
||||
await User.prisma().update(
|
||||
where={"id": user_id},
|
||||
data={"metadata": SafeJson(metadata.model_dump())},
|
||||
)
|
||||
return User.from_db(user)
|
||||
|
||||
|
||||
async def get_user_integrations(user_id: str) -> UserIntegrations:
|
||||
user = await User.prisma().find_unique_or_raise(
|
||||
user = await PrismaUser.prisma().find_unique_or_raise(
|
||||
where={"id": user_id},
|
||||
)
|
||||
|
||||
@@ -125,7 +110,7 @@ async def get_user_integrations(user_id: str) -> UserIntegrations:
|
||||
|
||||
async def update_user_integrations(user_id: str, data: UserIntegrations):
|
||||
encrypted_data = JSONCryptor().encrypt(data.model_dump(exclude_none=True))
|
||||
await User.prisma().update(
|
||||
await PrismaUser.prisma().update(
|
||||
where={"id": user_id},
|
||||
data={"integrations": encrypted_data},
|
||||
)
|
||||
@@ -133,7 +118,7 @@ async def update_user_integrations(user_id: str, data: UserIntegrations):
|
||||
|
||||
async def migrate_and_encrypt_user_integrations():
|
||||
"""Migrate integration credentials and OAuth states from metadata to integrations column."""
|
||||
users = await User.prisma().find_many(
|
||||
users = await PrismaUser.prisma().find_many(
|
||||
where={
|
||||
"metadata": cast(
|
||||
JsonFilter,
|
||||
@@ -169,7 +154,7 @@ async def migrate_and_encrypt_user_integrations():
|
||||
raw_metadata.pop("integration_oauth_states", None)
|
||||
|
||||
# Update metadata without integration data
|
||||
await User.prisma().update(
|
||||
await PrismaUser.prisma().update(
|
||||
where={"id": user.id},
|
||||
data={"metadata": SafeJson(raw_metadata)},
|
||||
)
|
||||
@@ -177,7 +162,7 @@ async def migrate_and_encrypt_user_integrations():
|
||||
|
||||
async def get_active_user_ids_in_timerange(start_time: str, end_time: str) -> list[str]:
|
||||
try:
|
||||
users = await User.prisma().find_many(
|
||||
users = await PrismaUser.prisma().find_many(
|
||||
where={
|
||||
"AgentGraphExecutions": {
|
||||
"some": {
|
||||
@@ -207,7 +192,7 @@ async def get_active_users_ids() -> list[str]:
|
||||
|
||||
async def get_user_notification_preference(user_id: str) -> NotificationPreference:
|
||||
try:
|
||||
user = await User.prisma().find_unique_or_raise(
|
||||
user = await PrismaUser.prisma().find_unique_or_raise(
|
||||
where={"id": user_id},
|
||||
)
|
||||
|
||||
@@ -284,7 +269,7 @@ async def update_user_notification_preference(
|
||||
if data.daily_limit:
|
||||
update_data["maxEmailsPerDay"] = data.daily_limit
|
||||
|
||||
user = await User.prisma().update(
|
||||
user = await PrismaUser.prisma().update(
|
||||
where={"id": user_id},
|
||||
data=update_data,
|
||||
)
|
||||
@@ -322,7 +307,7 @@ async def update_user_notification_preference(
|
||||
async def set_user_email_verification(user_id: str, verified: bool) -> None:
|
||||
"""Set the email verification status for a user."""
|
||||
try:
|
||||
await User.prisma().update(
|
||||
await PrismaUser.prisma().update(
|
||||
where={"id": user_id},
|
||||
data={"emailVerified": verified},
|
||||
)
|
||||
@@ -335,7 +320,7 @@ async def set_user_email_verification(user_id: str, verified: bool) -> None:
|
||||
async def get_user_email_verification(user_id: str) -> bool:
|
||||
"""Get the email verification status for a user."""
|
||||
try:
|
||||
user = await User.prisma().find_unique_or_raise(
|
||||
user = await PrismaUser.prisma().find_unique_or_raise(
|
||||
where={"id": user_id},
|
||||
)
|
||||
return user.emailVerified
|
||||
@@ -348,7 +333,7 @@ async def get_user_email_verification(user_id: str) -> bool:
|
||||
def generate_unsubscribe_link(user_id: str) -> str:
|
||||
"""Generate a link to unsubscribe from all notifications"""
|
||||
# Create an HMAC using a secret key
|
||||
secret_key = Settings().secrets.unsubscribe_secret_key
|
||||
secret_key = settings.secrets.unsubscribe_secret_key
|
||||
signature = hmac.new(
|
||||
secret_key.encode("utf-8"), user_id.encode("utf-8"), hashlib.sha256
|
||||
).digest()
|
||||
@@ -359,7 +344,7 @@ def generate_unsubscribe_link(user_id: str) -> str:
|
||||
).decode("utf-8")
|
||||
logger.info(f"Generating unsubscribe link for user {user_id}")
|
||||
|
||||
base_url = Settings().config.platform_base_url
|
||||
base_url = settings.config.platform_base_url
|
||||
return f"{base_url}/api/email/unsubscribe?token={quote_plus(token)}"
|
||||
|
||||
|
||||
@@ -371,7 +356,7 @@ async def unsubscribe_user_by_token(token: str) -> None:
|
||||
user_id, received_signature_hex = decoded.split(":", 1)
|
||||
|
||||
# Verify the signature
|
||||
secret_key = Settings().secrets.unsubscribe_secret_key
|
||||
secret_key = settings.secrets.unsubscribe_secret_key
|
||||
expected_signature = hmac.new(
|
||||
secret_key.encode("utf-8"), user_id.encode("utf-8"), hashlib.sha256
|
||||
).digest()
|
||||
|
||||
@@ -0,0 +1,434 @@
|
||||
"""
|
||||
Module for generating AI-based activity status for graph executions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, NotRequired, TypedDict
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.llm import LlmModel, llm_call
|
||||
from backend.data.block import get_block
|
||||
from backend.data.execution import ExecutionStatus, NodeExecutionResult
|
||||
from backend.data.model import APIKeyCredentials, GraphExecutionStats
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.truncate import truncate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ErrorInfo(TypedDict):
|
||||
"""Type definition for error information."""
|
||||
|
||||
error: str
|
||||
execution_id: str
|
||||
timestamp: str
|
||||
|
||||
|
||||
class InputOutputInfo(TypedDict):
|
||||
"""Type definition for input/output information."""
|
||||
|
||||
execution_id: str
|
||||
output_data: dict[str, Any] # Used for both input and output data
|
||||
timestamp: str
|
||||
|
||||
|
||||
class NodeInfo(TypedDict):
|
||||
"""Type definition for node information."""
|
||||
|
||||
node_id: str
|
||||
block_id: str
|
||||
block_name: str
|
||||
block_description: str
|
||||
execution_count: int
|
||||
error_count: int
|
||||
recent_errors: list[ErrorInfo]
|
||||
recent_outputs: list[InputOutputInfo]
|
||||
recent_inputs: list[InputOutputInfo]
|
||||
|
||||
|
||||
class NodeRelation(TypedDict):
|
||||
"""Type definition for node relation information."""
|
||||
|
||||
source_node_id: str
|
||||
sink_node_id: str
|
||||
source_name: str
|
||||
sink_name: str
|
||||
is_static: bool
|
||||
source_block_name: NotRequired[str] # Optional, only set if block exists
|
||||
sink_block_name: NotRequired[str] # Optional, only set if block exists
|
||||
|
||||
|
||||
def _truncate_uuid(uuid_str: str) -> str:
|
||||
"""Truncate UUID to first segment to reduce payload size."""
|
||||
if not uuid_str:
|
||||
return uuid_str
|
||||
return uuid_str.split("-")[0] if "-" in uuid_str else uuid_str[:8]
|
||||
|
||||
|
||||
async def generate_activity_status_for_execution(
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_stats: GraphExecutionStats,
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
user_id: str,
|
||||
execution_status: ExecutionStatus | None = None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Generate an AI-based activity status summary for a graph execution.
|
||||
|
||||
This function handles all the data collection and AI generation logic,
|
||||
keeping the manager integration simple.
|
||||
|
||||
Args:
|
||||
graph_exec_id: The graph execution ID
|
||||
graph_id: The graph ID
|
||||
graph_version: The graph version
|
||||
execution_stats: Execution statistics
|
||||
db_client: Database client for fetching data
|
||||
user_id: User ID for LaunchDarkly feature flag evaluation
|
||||
execution_status: The overall execution status (COMPLETED, FAILED, TERMINATED)
|
||||
|
||||
Returns:
|
||||
AI-generated activity status string, or None if feature is disabled
|
||||
"""
|
||||
# Check LaunchDarkly feature flag for AI activity status generation with full context support
|
||||
if not await is_feature_enabled(Flag.AI_ACTIVITY_STATUS, user_id):
|
||||
logger.debug("AI activity status generation is disabled via LaunchDarkly")
|
||||
return None
|
||||
|
||||
# Check if we have OpenAI API key
|
||||
try:
|
||||
settings = Settings()
|
||||
if not settings.secrets.openai_api_key:
|
||||
logger.debug(
|
||||
"OpenAI API key not configured, skipping activity status generation"
|
||||
)
|
||||
return None
|
||||
|
||||
# Get all node executions for this graph execution
|
||||
node_executions = await db_client.get_node_executions(
|
||||
graph_exec_id, include_exec_data=True
|
||||
)
|
||||
|
||||
# Get graph metadata and full graph structure for name, description, and links
|
||||
graph_metadata = await db_client.get_graph_metadata(graph_id, graph_version)
|
||||
graph = await db_client.get_graph(graph_id, graph_version)
|
||||
|
||||
graph_name = graph_metadata.name if graph_metadata else f"Graph {graph_id}"
|
||||
graph_description = graph_metadata.description if graph_metadata else ""
|
||||
graph_links = graph.links if graph else []
|
||||
|
||||
# Build execution data summary
|
||||
execution_data = _build_execution_summary(
|
||||
node_executions,
|
||||
execution_stats,
|
||||
graph_name,
|
||||
graph_description,
|
||||
graph_links,
|
||||
execution_status,
|
||||
)
|
||||
|
||||
# Prepare prompt for AI
|
||||
prompt = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are an AI assistant summarizing what you just did for a user in simple, friendly language. "
|
||||
"Write from the user's perspective about what they accomplished, NOT about technical execution details. "
|
||||
"Focus on the ACTUAL TASK the user wanted done, not the internal workflow steps. "
|
||||
"Avoid technical terms like 'workflow', 'execution', 'components', 'nodes', 'processing', etc. "
|
||||
"Keep it to 3 sentences maximum. Be conversational and human-friendly.\n\n"
|
||||
"IMPORTANT: Be HONEST about what actually happened:\n"
|
||||
"- If the input was invalid/nonsensical, say so directly\n"
|
||||
"- If the task failed, explain what went wrong in simple terms\n"
|
||||
"- If errors occurred, focus on what the user needs to know\n"
|
||||
"- Only claim success if the task was genuinely completed\n"
|
||||
"- Don't sugar-coat failures or present them as helpful feedback\n\n"
|
||||
"Understanding Errors:\n"
|
||||
"- Node errors: Individual steps may fail but the overall task might still complete (e.g., one data source fails but others work)\n"
|
||||
"- Graph error (in overall_status.graph_error): This means the entire execution failed and nothing was accomplished\n"
|
||||
"- Even if execution shows 'completed', check if critical nodes failed that would prevent the desired outcome\n"
|
||||
"- Focus on the end result the user wanted, not whether technical steps completed"
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"A user ran '{graph_name}' to accomplish something. Based on this execution data, "
|
||||
f"write what they achieved in simple, user-friendly terms:\n\n"
|
||||
f"{json.dumps(execution_data, indent=2)}\n\n"
|
||||
"CRITICAL: Check overall_status.graph_error FIRST - if present, the entire execution failed.\n"
|
||||
"Then check individual node errors to understand partial failures.\n\n"
|
||||
"Write 1-3 sentences about what the user accomplished, such as:\n"
|
||||
"- 'I analyzed your resume and provided detailed feedback for the IT industry.'\n"
|
||||
"- 'I couldn't analyze your resume because the input was just nonsensical text.'\n"
|
||||
"- 'I failed to complete the task due to missing API access.'\n"
|
||||
"- 'I extracted key information from your documents and organized it into a summary.'\n"
|
||||
"- 'The task failed to run due to system configuration issues.'\n\n"
|
||||
"Focus on what ACTUALLY happened, not what was attempted."
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
# Log the prompt for debugging purposes
|
||||
logger.debug(
|
||||
f"Sending prompt to LLM for graph execution {graph_exec_id}: {json.dumps(prompt, indent=2)}"
|
||||
)
|
||||
|
||||
# Create credentials for LLM call
|
||||
credentials = APIKeyCredentials(
|
||||
id="openai",
|
||||
provider="openai",
|
||||
api_key=SecretStr(settings.secrets.openai_api_key),
|
||||
title="System OpenAI",
|
||||
)
|
||||
|
||||
# Make LLM call using current event loop
|
||||
activity_status = await _call_llm_direct(credentials, prompt)
|
||||
|
||||
logger.debug(
|
||||
f"Generated activity status for {graph_exec_id}: {activity_status}"
|
||||
)
|
||||
return activity_status
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to generate activity status for execution {graph_exec_id}: {str(e)}"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _build_execution_summary(
|
||||
node_executions: list[NodeExecutionResult],
|
||||
execution_stats: GraphExecutionStats,
|
||||
graph_name: str,
|
||||
graph_description: str,
|
||||
graph_links: list[Any],
|
||||
execution_status: ExecutionStatus | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a structured summary of execution data for AI analysis."""
|
||||
|
||||
nodes: list[NodeInfo] = []
|
||||
node_execution_counts: dict[str, int] = {}
|
||||
node_error_counts: dict[str, int] = {}
|
||||
node_errors: dict[str, list[ErrorInfo]] = {}
|
||||
node_outputs: dict[str, list[InputOutputInfo]] = {}
|
||||
node_inputs: dict[str, list[InputOutputInfo]] = {}
|
||||
input_output_data: dict[str, Any] = {}
|
||||
node_map: dict[str, NodeInfo] = {}
|
||||
|
||||
# Process node executions
|
||||
for node_exec in node_executions:
|
||||
block = get_block(node_exec.block_id)
|
||||
if not block:
|
||||
logger.warning(
|
||||
f"Block {node_exec.block_id} not found for node {node_exec.node_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Track execution counts per node
|
||||
if node_exec.node_id not in node_execution_counts:
|
||||
node_execution_counts[node_exec.node_id] = 0
|
||||
node_execution_counts[node_exec.node_id] += 1
|
||||
|
||||
# Track errors per node and group them
|
||||
if node_exec.status == ExecutionStatus.FAILED:
|
||||
if node_exec.node_id not in node_error_counts:
|
||||
node_error_counts[node_exec.node_id] = 0
|
||||
node_error_counts[node_exec.node_id] += 1
|
||||
|
||||
# Extract actual error message from output_data
|
||||
error_message = "Unknown error"
|
||||
if node_exec.output_data and isinstance(node_exec.output_data, dict):
|
||||
# Check if error is in output_data
|
||||
if "error" in node_exec.output_data:
|
||||
error_output = node_exec.output_data["error"]
|
||||
if isinstance(error_output, list) and error_output:
|
||||
error_message = str(error_output[0])
|
||||
else:
|
||||
error_message = str(error_output)
|
||||
|
||||
# Group errors by node_id
|
||||
if node_exec.node_id not in node_errors:
|
||||
node_errors[node_exec.node_id] = []
|
||||
|
||||
node_errors[node_exec.node_id].append(
|
||||
{
|
||||
"error": error_message,
|
||||
"execution_id": _truncate_uuid(node_exec.node_exec_id),
|
||||
"timestamp": node_exec.add_time.isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
# Collect output samples for each node (latest executions)
|
||||
if node_exec.output_data:
|
||||
if node_exec.node_id not in node_outputs:
|
||||
node_outputs[node_exec.node_id] = []
|
||||
|
||||
# Truncate output data to 100 chars to save space
|
||||
truncated_output = truncate(node_exec.output_data, 100)
|
||||
|
||||
node_outputs[node_exec.node_id].append(
|
||||
{
|
||||
"execution_id": _truncate_uuid(node_exec.node_exec_id),
|
||||
"output_data": truncated_output,
|
||||
"timestamp": node_exec.add_time.isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
# Collect input samples for each node (latest executions)
|
||||
if node_exec.input_data:
|
||||
if node_exec.node_id not in node_inputs:
|
||||
node_inputs[node_exec.node_id] = []
|
||||
|
||||
# Truncate input data to 100 chars to save space
|
||||
truncated_input = truncate(node_exec.input_data, 100)
|
||||
|
||||
node_inputs[node_exec.node_id].append(
|
||||
{
|
||||
"execution_id": _truncate_uuid(node_exec.node_exec_id),
|
||||
"output_data": truncated_input, # Reuse field name for consistency
|
||||
"timestamp": node_exec.add_time.isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
# Build node data (only add unique nodes)
|
||||
if node_exec.node_id not in node_map:
|
||||
node_data: NodeInfo = {
|
||||
"node_id": _truncate_uuid(node_exec.node_id),
|
||||
"block_id": _truncate_uuid(node_exec.block_id),
|
||||
"block_name": block.name,
|
||||
"block_description": block.description or "",
|
||||
"execution_count": 0, # Will be set later
|
||||
"error_count": 0, # Will be set later
|
||||
"recent_errors": [], # Will be set later
|
||||
"recent_outputs": [], # Will be set later
|
||||
"recent_inputs": [], # Will be set later
|
||||
}
|
||||
nodes.append(node_data)
|
||||
node_map[node_exec.node_id] = node_data
|
||||
|
||||
# Store input/output data for special blocks (input/output blocks)
|
||||
if block.name in ["AgentInputBlock", "AgentOutputBlock", "UserInputBlock"]:
|
||||
if node_exec.input_data:
|
||||
input_output_data[f"{node_exec.node_id}_inputs"] = dict(
|
||||
node_exec.input_data
|
||||
)
|
||||
if node_exec.output_data:
|
||||
input_output_data[f"{node_exec.node_id}_outputs"] = dict(
|
||||
node_exec.output_data
|
||||
)
|
||||
|
||||
# Add execution and error counts to node data, plus limited errors and output samples
|
||||
for node in nodes:
|
||||
# Use original node_id for lookups (before truncation)
|
||||
original_node_id = None
|
||||
for orig_id, node_data in node_map.items():
|
||||
if node_data == node:
|
||||
original_node_id = orig_id
|
||||
break
|
||||
|
||||
if original_node_id:
|
||||
node["execution_count"] = node_execution_counts.get(original_node_id, 0)
|
||||
node["error_count"] = node_error_counts.get(original_node_id, 0)
|
||||
|
||||
# Add limited errors for this node (latest 10 or first 5 + last 5)
|
||||
if original_node_id in node_errors:
|
||||
node_error_list = node_errors[original_node_id]
|
||||
if len(node_error_list) <= 10:
|
||||
node["recent_errors"] = node_error_list
|
||||
else:
|
||||
# First 5 + last 5 if more than 10 errors
|
||||
node["recent_errors"] = node_error_list[:5] + node_error_list[-5:]
|
||||
|
||||
# Add latest output samples (latest 3)
|
||||
if original_node_id in node_outputs:
|
||||
node_output_list = node_outputs[original_node_id]
|
||||
# Sort by timestamp if available, otherwise take last 3
|
||||
if node_output_list and node_output_list[0].get("timestamp"):
|
||||
node_output_list.sort(
|
||||
key=lambda x: x.get("timestamp", ""), reverse=True
|
||||
)
|
||||
node["recent_outputs"] = node_output_list[:3]
|
||||
|
||||
# Add latest input samples (latest 3)
|
||||
if original_node_id in node_inputs:
|
||||
node_input_list = node_inputs[original_node_id]
|
||||
# Sort by timestamp if available, otherwise take last 3
|
||||
if node_input_list and node_input_list[0].get("timestamp"):
|
||||
node_input_list.sort(
|
||||
key=lambda x: x.get("timestamp", ""), reverse=True
|
||||
)
|
||||
node["recent_inputs"] = node_input_list[:3]
|
||||
|
||||
# Build node relations from graph links
|
||||
node_relations: list[NodeRelation] = []
|
||||
for link in graph_links:
|
||||
# Include link details with source and sink information (truncated UUIDs)
|
||||
relation: NodeRelation = {
|
||||
"source_node_id": _truncate_uuid(link.source_id),
|
||||
"sink_node_id": _truncate_uuid(link.sink_id),
|
||||
"source_name": link.source_name,
|
||||
"sink_name": link.sink_name,
|
||||
"is_static": link.is_static if hasattr(link, "is_static") else False,
|
||||
}
|
||||
|
||||
# Add block names if nodes exist in our map
|
||||
if link.source_id in node_map:
|
||||
relation["source_block_name"] = node_map[link.source_id]["block_name"]
|
||||
if link.sink_id in node_map:
|
||||
relation["sink_block_name"] = node_map[link.sink_id]["block_name"]
|
||||
|
||||
node_relations.append(relation)
|
||||
|
||||
# Build overall summary
|
||||
return {
|
||||
"graph_info": {"name": graph_name, "description": graph_description},
|
||||
"nodes": nodes,
|
||||
"node_relations": node_relations,
|
||||
"input_output_data": input_output_data,
|
||||
"overall_status": {
|
||||
"total_nodes_in_graph": len(nodes),
|
||||
"total_executions": execution_stats.node_count,
|
||||
"total_errors": execution_stats.node_error_count,
|
||||
"execution_time_seconds": execution_stats.walltime,
|
||||
"has_errors": bool(
|
||||
execution_stats.error or execution_stats.node_error_count > 0
|
||||
),
|
||||
"graph_error": (
|
||||
str(execution_stats.error) if execution_stats.error else None
|
||||
),
|
||||
"graph_execution_status": (
|
||||
execution_status.value if execution_status else None
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@func_retry
|
||||
async def _call_llm_direct(
|
||||
credentials: APIKeyCredentials, prompt: list[dict[str, str]]
|
||||
) -> str:
|
||||
"""Make direct LLM call."""
|
||||
|
||||
response = await llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=LlmModel.GPT4O_MINI,
|
||||
prompt=prompt,
|
||||
json_format=False,
|
||||
max_tokens=150,
|
||||
compress_prompt_to_fit=True,
|
||||
)
|
||||
|
||||
if response and response.response:
|
||||
return response.response.strip()
|
||||
else:
|
||||
return "Unable to generate activity summary"
|
||||
@@ -0,0 +1,702 @@
|
||||
"""
|
||||
Tests for activity status generator functionality.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.llm import LLMResponse
|
||||
from backend.data.execution import ExecutionStatus, NodeExecutionResult
|
||||
from backend.data.model import GraphExecutionStats
|
||||
from backend.executor.activity_status_generator import (
|
||||
_build_execution_summary,
|
||||
_call_llm_direct,
|
||||
generate_activity_status_for_execution,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_node_executions():
|
||||
"""Create mock node executions for testing."""
|
||||
return [
|
||||
NodeExecutionResult(
|
||||
user_id="test_user",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test_exec",
|
||||
node_exec_id="123e4567-e89b-12d3-a456-426614174001",
|
||||
node_id="456e7890-e89b-12d3-a456-426614174002",
|
||||
block_id="789e1234-e89b-12d3-a456-426614174003",
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
input_data={"user_input": "Hello, world!"},
|
||||
output_data={"processed_input": ["Hello, world!"]},
|
||||
add_time=datetime.now(timezone.utc),
|
||||
queue_time=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
),
|
||||
NodeExecutionResult(
|
||||
user_id="test_user",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test_exec",
|
||||
node_exec_id="234e5678-e89b-12d3-a456-426614174004",
|
||||
node_id="567e8901-e89b-12d3-a456-426614174005",
|
||||
block_id="890e2345-e89b-12d3-a456-426614174006",
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
input_data={"data": "Hello, world!"},
|
||||
output_data={"result": ["Processed data"]},
|
||||
add_time=datetime.now(timezone.utc),
|
||||
queue_time=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
),
|
||||
NodeExecutionResult(
|
||||
user_id="test_user",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test_exec",
|
||||
node_exec_id="345e6789-e89b-12d3-a456-426614174007",
|
||||
node_id="678e9012-e89b-12d3-a456-426614174008",
|
||||
block_id="901e3456-e89b-12d3-a456-426614174009",
|
||||
status=ExecutionStatus.FAILED,
|
||||
input_data={"final_data": "Processed data"},
|
||||
output_data={
|
||||
"error": ["Connection timeout: Unable to reach external service"]
|
||||
},
|
||||
add_time=datetime.now(timezone.utc),
|
||||
queue_time=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_execution_stats():
|
||||
"""Create mock execution stats for testing."""
|
||||
return GraphExecutionStats(
|
||||
walltime=2.5,
|
||||
cputime=1.8,
|
||||
nodes_walltime=2.0,
|
||||
nodes_cputime=1.5,
|
||||
node_count=3,
|
||||
node_error_count=1,
|
||||
cost=10,
|
||||
error=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_execution_stats_with_graph_error():
|
||||
"""Create mock execution stats with graph-level error."""
|
||||
return GraphExecutionStats(
|
||||
walltime=2.5,
|
||||
cputime=1.8,
|
||||
nodes_walltime=2.0,
|
||||
nodes_cputime=1.5,
|
||||
node_count=3,
|
||||
node_error_count=1,
|
||||
cost=10,
|
||||
error="Graph execution failed: Invalid API credentials",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_blocks():
|
||||
"""Create mock blocks for testing."""
|
||||
input_block = MagicMock()
|
||||
input_block.name = "AgentInputBlock"
|
||||
input_block.description = "Handles user input"
|
||||
|
||||
process_block = MagicMock()
|
||||
process_block.name = "ProcessingBlock"
|
||||
process_block.description = "Processes data"
|
||||
|
||||
output_block = MagicMock()
|
||||
output_block.name = "AgentOutputBlock"
|
||||
output_block.description = "Provides output to user"
|
||||
|
||||
return {
|
||||
"789e1234-e89b-12d3-a456-426614174003": input_block,
|
||||
"890e2345-e89b-12d3-a456-426614174006": process_block,
|
||||
"901e3456-e89b-12d3-a456-426614174009": output_block,
|
||||
"process_block_id": process_block, # Keep old key for different error format test
|
||||
}
|
||||
|
||||
|
||||
class TestBuildExecutionSummary:
|
||||
"""Tests for _build_execution_summary function."""
|
||||
|
||||
def test_build_summary_with_successful_execution(
|
||||
self, mock_node_executions, mock_execution_stats, mock_blocks
|
||||
):
|
||||
"""Test building summary for successful execution."""
|
||||
# Create mock links with realistic UUIDs
|
||||
mock_links = [
|
||||
MagicMock(
|
||||
source_id="456e7890-e89b-12d3-a456-426614174002",
|
||||
sink_id="567e8901-e89b-12d3-a456-426614174005",
|
||||
source_name="output",
|
||||
sink_name="input",
|
||||
is_static=False,
|
||||
)
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.executor.activity_status_generator.get_block"
|
||||
) as mock_get_block:
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
|
||||
summary = _build_execution_summary(
|
||||
mock_node_executions[:2],
|
||||
mock_execution_stats,
|
||||
"Test Graph",
|
||||
"A test graph for processing",
|
||||
mock_links,
|
||||
ExecutionStatus.COMPLETED,
|
||||
)
|
||||
|
||||
# Check graph info
|
||||
assert summary["graph_info"]["name"] == "Test Graph"
|
||||
assert summary["graph_info"]["description"] == "A test graph for processing"
|
||||
|
||||
# Check nodes with per-node counts
|
||||
assert len(summary["nodes"]) == 2
|
||||
assert summary["nodes"][0]["block_name"] == "AgentInputBlock"
|
||||
assert summary["nodes"][0]["execution_count"] == 1
|
||||
assert summary["nodes"][0]["error_count"] == 0
|
||||
assert summary["nodes"][1]["block_name"] == "ProcessingBlock"
|
||||
assert summary["nodes"][1]["execution_count"] == 1
|
||||
assert summary["nodes"][1]["error_count"] == 0
|
||||
|
||||
# Check node relations (UUIDs are truncated to first segment)
|
||||
assert len(summary["node_relations"]) == 1
|
||||
assert (
|
||||
summary["node_relations"][0]["source_node_id"] == "456e7890"
|
||||
) # Truncated
|
||||
assert (
|
||||
summary["node_relations"][0]["sink_node_id"] == "567e8901"
|
||||
) # Truncated
|
||||
assert (
|
||||
summary["node_relations"][0]["source_block_name"] == "AgentInputBlock"
|
||||
)
|
||||
assert summary["node_relations"][0]["sink_block_name"] == "ProcessingBlock"
|
||||
|
||||
# Check overall status
|
||||
assert summary["overall_status"]["total_nodes_in_graph"] == 2
|
||||
assert summary["overall_status"]["total_executions"] == 3
|
||||
assert summary["overall_status"]["total_errors"] == 1
|
||||
assert summary["overall_status"]["execution_time_seconds"] == 2.5
|
||||
assert summary["overall_status"]["graph_execution_status"] == "COMPLETED"
|
||||
|
||||
# Check input/output data (using actual node UUIDs)
|
||||
assert (
|
||||
"456e7890-e89b-12d3-a456-426614174002_inputs"
|
||||
in summary["input_output_data"]
|
||||
)
|
||||
assert (
|
||||
"456e7890-e89b-12d3-a456-426614174002_outputs"
|
||||
in summary["input_output_data"]
|
||||
)
|
||||
|
||||
def test_build_summary_with_failed_execution(
|
||||
self, mock_node_executions, mock_execution_stats, mock_blocks
|
||||
):
|
||||
"""Test building summary for execution with failures."""
|
||||
mock_links = [] # No links for this test
|
||||
|
||||
with patch(
|
||||
"backend.executor.activity_status_generator.get_block"
|
||||
) as mock_get_block:
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
|
||||
summary = _build_execution_summary(
|
||||
mock_node_executions,
|
||||
mock_execution_stats,
|
||||
"Failed Graph",
|
||||
"Test with failures",
|
||||
mock_links,
|
||||
ExecutionStatus.FAILED,
|
||||
)
|
||||
|
||||
# Check that errors are now in node's recent_errors field
|
||||
# Find the output node (with truncated UUID)
|
||||
output_node = next(
|
||||
n for n in summary["nodes"] if n["node_id"] == "678e9012" # Truncated
|
||||
)
|
||||
assert output_node["error_count"] == 1
|
||||
assert output_node["execution_count"] == 1
|
||||
|
||||
# Check recent_errors field
|
||||
assert "recent_errors" in output_node
|
||||
assert len(output_node["recent_errors"]) == 1
|
||||
assert (
|
||||
output_node["recent_errors"][0]["error"]
|
||||
== "Connection timeout: Unable to reach external service"
|
||||
)
|
||||
assert (
|
||||
"execution_id" in output_node["recent_errors"][0]
|
||||
) # Should include execution ID
|
||||
|
||||
def test_build_summary_with_missing_blocks(
|
||||
self, mock_node_executions, mock_execution_stats
|
||||
):
|
||||
"""Test building summary when blocks are missing."""
|
||||
with patch(
|
||||
"backend.executor.activity_status_generator.get_block"
|
||||
) as mock_get_block:
|
||||
mock_get_block.return_value = None
|
||||
|
||||
summary = _build_execution_summary(
|
||||
mock_node_executions,
|
||||
mock_execution_stats,
|
||||
"Missing Blocks Graph",
|
||||
"Test with missing blocks",
|
||||
[],
|
||||
ExecutionStatus.COMPLETED,
|
||||
)
|
||||
|
||||
# Should handle missing blocks gracefully
|
||||
assert len(summary["nodes"]) == 0
|
||||
# No top-level errors field anymore, errors are in nodes' recent_errors
|
||||
assert summary["graph_info"]["name"] == "Missing Blocks Graph"
|
||||
|
||||
def test_build_summary_with_graph_error(
|
||||
self, mock_node_executions, mock_execution_stats_with_graph_error, mock_blocks
|
||||
):
|
||||
"""Test building summary with graph-level error."""
|
||||
mock_links = []
|
||||
|
||||
with patch(
|
||||
"backend.executor.activity_status_generator.get_block"
|
||||
) as mock_get_block:
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
|
||||
summary = _build_execution_summary(
|
||||
mock_node_executions,
|
||||
mock_execution_stats_with_graph_error,
|
||||
"Graph with Error",
|
||||
"Test with graph error",
|
||||
mock_links,
|
||||
ExecutionStatus.FAILED,
|
||||
)
|
||||
|
||||
# Check that graph error is included in overall status
|
||||
assert summary["overall_status"]["has_errors"] is True
|
||||
assert (
|
||||
summary["overall_status"]["graph_error"]
|
||||
== "Graph execution failed: Invalid API credentials"
|
||||
)
|
||||
assert summary["overall_status"]["total_errors"] == 1
|
||||
assert summary["overall_status"]["graph_execution_status"] == "FAILED"
|
||||
|
||||
def test_build_summary_with_different_error_formats(
|
||||
self, mock_execution_stats, mock_blocks
|
||||
):
|
||||
"""Test building summary with different error formats."""
|
||||
# Create node executions with different error formats and realistic UUIDs
|
||||
mock_executions = [
|
||||
NodeExecutionResult(
|
||||
user_id="test_user",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test_exec",
|
||||
node_exec_id="111e2222-e89b-12d3-a456-426614174010",
|
||||
node_id="333e4444-e89b-12d3-a456-426614174011",
|
||||
block_id="process_block_id",
|
||||
status=ExecutionStatus.FAILED,
|
||||
input_data={},
|
||||
output_data={"error": ["Simple string error message"]},
|
||||
add_time=datetime.now(timezone.utc),
|
||||
queue_time=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
),
|
||||
NodeExecutionResult(
|
||||
user_id="test_user",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test_exec",
|
||||
node_exec_id="555e6666-e89b-12d3-a456-426614174012",
|
||||
node_id="777e8888-e89b-12d3-a456-426614174013",
|
||||
block_id="process_block_id",
|
||||
status=ExecutionStatus.FAILED,
|
||||
input_data={},
|
||||
output_data={}, # No error in output
|
||||
add_time=datetime.now(timezone.utc),
|
||||
queue_time=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.executor.activity_status_generator.get_block"
|
||||
) as mock_get_block:
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
|
||||
summary = _build_execution_summary(
|
||||
mock_executions,
|
||||
mock_execution_stats,
|
||||
"Error Test Graph",
|
||||
"Testing error formats",
|
||||
[],
|
||||
ExecutionStatus.FAILED,
|
||||
)
|
||||
|
||||
# Check different error formats - errors are now in nodes' recent_errors
|
||||
error_nodes = [n for n in summary["nodes"] if n.get("recent_errors")]
|
||||
assert len(error_nodes) == 2
|
||||
|
||||
# String error format - find node with truncated ID
|
||||
string_error_node = next(
|
||||
n for n in summary["nodes"] if n["node_id"] == "333e4444" # Truncated
|
||||
)
|
||||
assert len(string_error_node["recent_errors"]) == 1
|
||||
assert (
|
||||
string_error_node["recent_errors"][0]["error"]
|
||||
== "Simple string error message"
|
||||
)
|
||||
|
||||
# No error output format - find node with truncated ID
|
||||
no_error_node = next(
|
||||
n for n in summary["nodes"] if n["node_id"] == "777e8888" # Truncated
|
||||
)
|
||||
assert len(no_error_node["recent_errors"]) == 1
|
||||
assert no_error_node["recent_errors"][0]["error"] == "Unknown error"
|
||||
|
||||
|
||||
class TestLLMCall:
|
||||
"""Tests for LLM calling functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_llm_direct_success(self):
|
||||
"""Test successful LLM call."""
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials
|
||||
|
||||
mock_response = LLMResponse(
|
||||
raw_response={},
|
||||
prompt=[],
|
||||
response="Agent successfully processed user input and generated response.",
|
||||
tool_calls=None,
|
||||
prompt_tokens=50,
|
||||
completion_tokens=20,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.activity_status_generator.llm_call"
|
||||
) as mock_llm_call:
|
||||
mock_llm_call.return_value = mock_response
|
||||
|
||||
credentials = APIKeyCredentials(
|
||||
id="test",
|
||||
provider="openai",
|
||||
api_key=SecretStr("test_key"),
|
||||
title="Test",
|
||||
)
|
||||
|
||||
prompt = [{"role": "user", "content": "Test prompt"}]
|
||||
|
||||
result = await _call_llm_direct(credentials, prompt)
|
||||
|
||||
assert (
|
||||
result
|
||||
== "Agent successfully processed user input and generated response."
|
||||
)
|
||||
mock_llm_call.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_llm_direct_no_response(self):
|
||||
"""Test LLM call with no response."""
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials
|
||||
|
||||
with patch(
|
||||
"backend.executor.activity_status_generator.llm_call"
|
||||
) as mock_llm_call:
|
||||
mock_llm_call.return_value = None
|
||||
|
||||
credentials = APIKeyCredentials(
|
||||
id="test",
|
||||
provider="openai",
|
||||
api_key=SecretStr("test_key"),
|
||||
title="Test",
|
||||
)
|
||||
|
||||
prompt = [{"role": "user", "content": "Test prompt"}]
|
||||
|
||||
result = await _call_llm_direct(credentials, prompt)
|
||||
|
||||
assert result == "Unable to generate activity summary"
|
||||
|
||||
|
||||
class TestGenerateActivityStatusForExecution:
|
||||
"""Tests for the main generate_activity_status_for_execution function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_status_success(
|
||||
self, mock_node_executions, mock_execution_stats, mock_blocks
|
||||
):
|
||||
"""Test successful activity status generation."""
|
||||
mock_db_client = AsyncMock()
|
||||
mock_db_client.get_node_executions.return_value = mock_node_executions
|
||||
|
||||
mock_graph_metadata = MagicMock()
|
||||
mock_graph_metadata.name = "Test Agent"
|
||||
mock_graph_metadata.description = "A test agent"
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
|
||||
|
||||
mock_graph = MagicMock()
|
||||
mock_graph.links = []
|
||||
mock_db_client.get_graph.return_value = mock_graph
|
||||
|
||||
with patch(
|
||||
"backend.executor.activity_status_generator.get_block"
|
||||
) as mock_get_block, patch(
|
||||
"backend.executor.activity_status_generator.Settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.activity_status_generator._call_llm_direct"
|
||||
) as mock_llm, patch(
|
||||
"backend.executor.activity_status_generator.is_feature_enabled",
|
||||
return_value=True,
|
||||
):
|
||||
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
mock_settings.return_value.secrets.openai_api_key = "test_key"
|
||||
mock_llm.return_value = (
|
||||
"I analyzed your data and provided the requested insights."
|
||||
)
|
||||
|
||||
result = await generate_activity_status_for_execution(
|
||||
graph_exec_id="test_exec",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
execution_stats=mock_execution_stats,
|
||||
db_client=mock_db_client,
|
||||
user_id="test_user",
|
||||
)
|
||||
|
||||
assert result == "I analyzed your data and provided the requested insights."
|
||||
mock_db_client.get_node_executions.assert_called_once()
|
||||
mock_db_client.get_graph_metadata.assert_called_once()
|
||||
mock_db_client.get_graph.assert_called_once()
|
||||
mock_llm.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_status_feature_disabled(self, mock_execution_stats):
|
||||
"""Test activity status generation when feature is disabled."""
|
||||
mock_db_client = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.executor.activity_status_generator.is_feature_enabled",
|
||||
return_value=False,
|
||||
):
|
||||
result = await generate_activity_status_for_execution(
|
||||
graph_exec_id="test_exec",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
execution_stats=mock_execution_stats,
|
||||
db_client=mock_db_client,
|
||||
user_id="test_user",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
mock_db_client.get_node_executions.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_status_no_api_key(self, mock_execution_stats):
|
||||
"""Test activity status generation with no API key."""
|
||||
mock_db_client = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.executor.activity_status_generator.Settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.activity_status_generator.is_feature_enabled",
|
||||
return_value=True,
|
||||
):
|
||||
mock_settings.return_value.secrets.openai_api_key = ""
|
||||
|
||||
result = await generate_activity_status_for_execution(
|
||||
graph_exec_id="test_exec",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
execution_stats=mock_execution_stats,
|
||||
db_client=mock_db_client,
|
||||
user_id="test_user",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
mock_db_client.get_node_executions.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_status_exception_handling(self, mock_execution_stats):
|
||||
"""Test activity status generation with exception."""
|
||||
mock_db_client = AsyncMock()
|
||||
mock_db_client.get_node_executions.side_effect = Exception("Database error")
|
||||
|
||||
with patch(
|
||||
"backend.executor.activity_status_generator.Settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.activity_status_generator.is_feature_enabled",
|
||||
return_value=True,
|
||||
):
|
||||
mock_settings.return_value.secrets.openai_api_key = "test_key"
|
||||
|
||||
result = await generate_activity_status_for_execution(
|
||||
graph_exec_id="test_exec",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
execution_stats=mock_execution_stats,
|
||||
db_client=mock_db_client,
|
||||
user_id="test_user",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_status_with_graph_name_fallback(
|
||||
self, mock_node_executions, mock_execution_stats, mock_blocks
|
||||
):
|
||||
"""Test activity status generation with graph name fallback."""
|
||||
mock_db_client = AsyncMock()
|
||||
mock_db_client.get_node_executions.return_value = mock_node_executions
|
||||
mock_db_client.get_graph_metadata.return_value = None # No metadata
|
||||
mock_db_client.get_graph.return_value = None # No graph
|
||||
|
||||
with patch(
|
||||
"backend.executor.activity_status_generator.get_block"
|
||||
) as mock_get_block, patch(
|
||||
"backend.executor.activity_status_generator.Settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.activity_status_generator._call_llm_direct"
|
||||
) as mock_llm, patch(
|
||||
"backend.executor.activity_status_generator.is_feature_enabled",
|
||||
return_value=True,
|
||||
):
|
||||
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
mock_settings.return_value.secrets.openai_api_key = "test_key"
|
||||
mock_llm.return_value = "Agent completed execution."
|
||||
|
||||
result = await generate_activity_status_for_execution(
|
||||
graph_exec_id="test_exec",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
execution_stats=mock_execution_stats,
|
||||
db_client=mock_db_client,
|
||||
user_id="test_user",
|
||||
)
|
||||
|
||||
assert result == "Agent completed execution."
|
||||
# Should use fallback graph name in prompt
|
||||
call_args = mock_llm.call_args[0][1] # prompt argument
|
||||
assert "Graph test_graph" in call_args[1]["content"]
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
"""Integration tests to verify the complete flow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_integration_flow(
|
||||
self, mock_node_executions, mock_execution_stats, mock_blocks
|
||||
):
|
||||
"""Test the complete integration flow."""
|
||||
mock_db_client = AsyncMock()
|
||||
mock_db_client.get_node_executions.return_value = mock_node_executions
|
||||
|
||||
mock_graph_metadata = MagicMock()
|
||||
mock_graph_metadata.name = "Test Integration Agent"
|
||||
mock_graph_metadata.description = "Integration test agent"
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
|
||||
|
||||
mock_graph = MagicMock()
|
||||
mock_graph.links = []
|
||||
mock_db_client.get_graph.return_value = mock_graph
|
||||
|
||||
expected_activity = "I processed user input but failed during final output generation due to system error."
|
||||
|
||||
with patch(
|
||||
"backend.executor.activity_status_generator.get_block"
|
||||
) as mock_get_block, patch(
|
||||
"backend.executor.activity_status_generator.Settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.activity_status_generator.llm_call"
|
||||
) as mock_llm_call, patch(
|
||||
"backend.executor.activity_status_generator.is_feature_enabled",
|
||||
return_value=True,
|
||||
):
|
||||
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
mock_settings.return_value.secrets.openai_api_key = "test_key"
|
||||
|
||||
mock_response = LLMResponse(
|
||||
raw_response={},
|
||||
prompt=[],
|
||||
response=expected_activity,
|
||||
tool_calls=None,
|
||||
prompt_tokens=100,
|
||||
completion_tokens=30,
|
||||
)
|
||||
mock_llm_call.return_value = mock_response
|
||||
|
||||
result = await generate_activity_status_for_execution(
|
||||
graph_exec_id="test_exec",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
execution_stats=mock_execution_stats,
|
||||
db_client=mock_db_client,
|
||||
user_id="test_user",
|
||||
)
|
||||
|
||||
assert result == expected_activity
|
||||
|
||||
# Verify the correct data was passed to LLM
|
||||
llm_call_args = mock_llm_call.call_args
|
||||
prompt = llm_call_args[1]["prompt"]
|
||||
|
||||
# Check system prompt
|
||||
assert prompt[0]["role"] == "system"
|
||||
assert "user's perspective" in prompt[0]["content"]
|
||||
|
||||
# Check user prompt contains expected data
|
||||
user_content = prompt[1]["content"]
|
||||
assert "Test Integration Agent" in user_content
|
||||
assert "user-friendly terms" in user_content.lower()
|
||||
|
||||
# Verify that execution data is present in the prompt
|
||||
assert "{" in user_content # Should contain JSON data
|
||||
assert "overall_status" in user_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_integration_with_disabled_feature(
|
||||
self, mock_execution_stats
|
||||
):
|
||||
"""Test that when feature returns None, manager doesn't set activity_status."""
|
||||
mock_db_client = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.executor.activity_status_generator.is_feature_enabled",
|
||||
return_value=False,
|
||||
):
|
||||
result = await generate_activity_status_for_execution(
|
||||
graph_exec_id="test_exec",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
execution_stats=mock_execution_stats,
|
||||
db_client=mock_db_client,
|
||||
user_id="test_user",
|
||||
)
|
||||
|
||||
# Should return None when disabled
|
||||
assert result is None
|
||||
|
||||
# Verify no database calls were made
|
||||
mock_db_client.get_node_executions.assert_not_called()
|
||||
mock_db_client.get_graph_metadata.assert_not_called()
|
||||
mock_db_client.get_graph.assert_not_called()
|
||||
@@ -7,7 +7,6 @@ from backend.data.execution import (
|
||||
create_graph_execution,
|
||||
get_block_error_stats,
|
||||
get_execution_kv_data,
|
||||
get_graph_execution,
|
||||
get_graph_execution_meta,
|
||||
get_graph_executions,
|
||||
get_latest_node_execution,
|
||||
@@ -16,12 +15,12 @@ from backend.data.execution import (
|
||||
set_execution_kv_data,
|
||||
update_graph_execution_start_time,
|
||||
update_graph_execution_stats,
|
||||
update_node_execution_stats,
|
||||
update_node_execution_status,
|
||||
update_node_execution_status_batch,
|
||||
upsert_execution_input,
|
||||
upsert_execution_output,
|
||||
)
|
||||
from backend.data.generate_data import get_user_execution_summary_data
|
||||
from backend.data.graph import (
|
||||
get_connected_output_nodes,
|
||||
get_graph,
|
||||
@@ -40,12 +39,16 @@ from backend.data.user import (
|
||||
get_user_email_by_id,
|
||||
get_user_email_verification,
|
||||
get_user_integrations,
|
||||
get_user_metadata,
|
||||
get_user_notification_preference,
|
||||
update_user_integrations,
|
||||
update_user_metadata,
|
||||
)
|
||||
from backend.util.service import AppService, AppServiceClient, endpoint_to_sync, expose
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
UnhealthyServiceError,
|
||||
endpoint_to_sync,
|
||||
expose,
|
||||
)
|
||||
from backend.util.settings import Config
|
||||
|
||||
config = Config()
|
||||
@@ -77,6 +80,11 @@ class DatabaseManager(AppService):
|
||||
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
|
||||
self.run_and_wait(db.disconnect())
|
||||
|
||||
async def health_check(self) -> str:
|
||||
if not db.is_connected():
|
||||
raise UnhealthyServiceError("Database is not connected")
|
||||
return await super().health_check()
|
||||
|
||||
@classmethod
|
||||
def get_port(cls) -> int:
|
||||
return config.database_api_port
|
||||
@@ -90,7 +98,6 @@ class DatabaseManager(AppService):
|
||||
return cast(Callable[Concatenate[object, P], R], expose(f))
|
||||
|
||||
# Executions
|
||||
get_graph_execution = _(get_graph_execution)
|
||||
get_graph_executions = _(get_graph_executions)
|
||||
get_graph_execution_meta = _(get_graph_execution_meta)
|
||||
create_graph_execution = _(create_graph_execution)
|
||||
@@ -101,7 +108,6 @@ class DatabaseManager(AppService):
|
||||
update_node_execution_status_batch = _(update_node_execution_status_batch)
|
||||
update_graph_execution_start_time = _(update_graph_execution_start_time)
|
||||
update_graph_execution_stats = _(update_graph_execution_stats)
|
||||
update_node_execution_stats = _(update_node_execution_stats)
|
||||
upsert_execution_input = _(upsert_execution_input)
|
||||
upsert_execution_output = _(upsert_execution_output)
|
||||
get_execution_kv_data = _(get_execution_kv_data)
|
||||
@@ -119,8 +125,6 @@ class DatabaseManager(AppService):
|
||||
get_credits = _(_get_credits, name="get_credits")
|
||||
|
||||
# User + User Metadata + User Integrations
|
||||
get_user_metadata = _(get_user_metadata)
|
||||
update_user_metadata = _(update_user_metadata)
|
||||
get_user_integrations = _(get_user_integrations)
|
||||
update_user_integrations = _(update_user_integrations)
|
||||
|
||||
@@ -141,6 +145,9 @@ class DatabaseManager(AppService):
|
||||
get_user_notification_oldest_message_in_batch
|
||||
)
|
||||
|
||||
# Summary data - async
|
||||
get_user_execution_summary_data = _(get_user_execution_summary_data)
|
||||
|
||||
|
||||
class DatabaseManagerClient(AppServiceClient):
|
||||
d = DatabaseManager
|
||||
@@ -151,55 +158,23 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
return DatabaseManager
|
||||
|
||||
# Executions
|
||||
get_graph_execution = _(d.get_graph_execution)
|
||||
get_graph_executions = _(d.get_graph_executions)
|
||||
get_graph_execution_meta = _(d.get_graph_execution_meta)
|
||||
create_graph_execution = _(d.create_graph_execution)
|
||||
get_node_execution = _(d.get_node_execution)
|
||||
get_node_executions = _(d.get_node_executions)
|
||||
get_latest_node_execution = _(d.get_latest_node_execution)
|
||||
update_node_execution_status = _(d.update_node_execution_status)
|
||||
update_node_execution_status_batch = _(d.update_node_execution_status_batch)
|
||||
update_graph_execution_start_time = _(d.update_graph_execution_start_time)
|
||||
update_graph_execution_stats = _(d.update_graph_execution_stats)
|
||||
update_node_execution_stats = _(d.update_node_execution_stats)
|
||||
upsert_execution_input = _(d.upsert_execution_input)
|
||||
upsert_execution_output = _(d.upsert_execution_output)
|
||||
get_execution_kv_data = _(d.get_execution_kv_data)
|
||||
set_execution_kv_data = _(d.set_execution_kv_data)
|
||||
|
||||
# Graphs
|
||||
get_node = _(d.get_node)
|
||||
get_graph = _(d.get_graph)
|
||||
get_connected_output_nodes = _(d.get_connected_output_nodes)
|
||||
get_graph_metadata = _(d.get_graph_metadata)
|
||||
|
||||
# Credits
|
||||
spend_credits = _(d.spend_credits)
|
||||
get_credits = _(d.get_credits)
|
||||
|
||||
# User + User Metadata + User Integrations
|
||||
get_user_metadata = _(d.get_user_metadata)
|
||||
update_user_metadata = _(d.update_user_metadata)
|
||||
get_user_integrations = _(d.get_user_integrations)
|
||||
update_user_integrations = _(d.update_user_integrations)
|
||||
|
||||
# User Comms - async
|
||||
get_active_user_ids_in_timerange = _(d.get_active_user_ids_in_timerange)
|
||||
get_user_email_by_id = _(d.get_user_email_by_id)
|
||||
get_user_email_verification = _(d.get_user_email_verification)
|
||||
get_user_notification_preference = _(d.get_user_notification_preference)
|
||||
|
||||
# Notifications - async
|
||||
create_or_add_to_user_notification_batch = _(
|
||||
d.create_or_add_to_user_notification_batch
|
||||
)
|
||||
empty_user_notification_batch = _(d.empty_user_notification_batch)
|
||||
get_all_batches_by_type = _(d.get_all_batches_by_type)
|
||||
get_user_notification_batch = _(d.get_user_notification_batch)
|
||||
get_user_notification_oldest_message_in_batch = _(
|
||||
d.get_user_notification_oldest_message_in_batch
|
||||
)
|
||||
# Summary data - async
|
||||
get_user_execution_summary_data = _(d.get_user_execution_summary_data)
|
||||
|
||||
# Block error monitoring
|
||||
get_block_error_stats = _(d.get_block_error_stats)
|
||||
@@ -225,10 +200,28 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
upsert_execution_input = d.upsert_execution_input
|
||||
upsert_execution_output = d.upsert_execution_output
|
||||
update_graph_execution_stats = d.update_graph_execution_stats
|
||||
update_node_execution_stats = d.update_node_execution_stats
|
||||
update_node_execution_status = d.update_node_execution_status
|
||||
update_node_execution_status_batch = d.update_node_execution_status_batch
|
||||
update_user_integrations = d.update_user_integrations
|
||||
get_execution_kv_data = d.get_execution_kv_data
|
||||
set_execution_kv_data = d.set_execution_kv_data
|
||||
get_block_error_stats = d.get_block_error_stats
|
||||
|
||||
# User Comms
|
||||
get_active_user_ids_in_timerange = d.get_active_user_ids_in_timerange
|
||||
get_user_email_by_id = d.get_user_email_by_id
|
||||
get_user_email_verification = d.get_user_email_verification
|
||||
get_user_notification_preference = d.get_user_notification_preference
|
||||
|
||||
# Notifications
|
||||
create_or_add_to_user_notification_batch = (
|
||||
d.create_or_add_to_user_notification_batch
|
||||
)
|
||||
empty_user_notification_batch = d.empty_user_notification_batch
|
||||
get_all_batches_by_type = d.get_all_batches_by_type
|
||||
get_user_notification_batch = d.get_user_notification_batch
|
||||
get_user_notification_oldest_message_in_batch = (
|
||||
d.get_user_notification_oldest_message_in_batch
|
||||
)
|
||||
|
||||
# Summary data
|
||||
get_user_execution_summary_data = d.get_user_execution_summary_data
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,7 +3,6 @@ import logging
|
||||
import autogpt_libs.auth.models
|
||||
import fastapi.responses
|
||||
import pytest
|
||||
from prisma.models import User
|
||||
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.store.model
|
||||
@@ -12,6 +11,7 @@ from backend.blocks.data_manipulation import FindInDictionaryBlock
|
||||
from backend.blocks.io import AgentInputBlock
|
||||
from backend.blocks.maths import CalculatorBlock, Operation
|
||||
from backend.data import execution, graph
|
||||
from backend.data.model import User
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
|
||||
@@ -1,17 +1,22 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
|
||||
from apscheduler.events import (
|
||||
EVENT_JOB_ERROR,
|
||||
EVENT_JOB_EXECUTED,
|
||||
EVENT_JOB_MAX_INSTANCES,
|
||||
EVENT_JOB_MISSED,
|
||||
)
|
||||
from apscheduler.job import Job as JobObj
|
||||
from apscheduler.jobstores.memory import MemoryJobStore
|
||||
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
|
||||
from apscheduler.schedulers.blocking import BlockingScheduler
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import MetaData, create_engine
|
||||
@@ -30,7 +35,14 @@ from backend.monitoring import (
|
||||
from backend.util.cloud_storage import cleanup_expired_files_async
|
||||
from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
from backend.util.logging import PrefixFilter
|
||||
from backend.util.service import AppService, AppServiceClient, endpoint_to_async, expose
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
UnhealthyServiceError,
|
||||
endpoint_to_async,
|
||||
expose,
|
||||
)
|
||||
from backend.util.settings import Config
|
||||
|
||||
|
||||
@@ -60,26 +72,69 @@ apscheduler_logger.addFilter(PrefixFilter("[Scheduler] [APScheduler]"))
|
||||
|
||||
config = Config()
|
||||
|
||||
# Timeout constants
|
||||
SCHEDULER_OPERATION_TIMEOUT_SECONDS = 300 # 5 minutes for scheduler operations
|
||||
|
||||
|
||||
def job_listener(event):
|
||||
"""Logs job execution outcomes for better monitoring."""
|
||||
if event.exception:
|
||||
logger.error(f"Job {event.job_id} failed.")
|
||||
logger.error(
|
||||
f"Job {event.job_id} failed: {type(event.exception).__name__}: {event.exception}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Job {event.job_id} completed successfully.")
|
||||
|
||||
|
||||
@thread_cached
|
||||
def job_missed_listener(event):
|
||||
"""Logs when jobs are missed due to scheduling issues."""
|
||||
logger.warning(
|
||||
f"Job {event.job_id} was missed at scheduled time {event.scheduled_run_time}. "
|
||||
f"This can happen if the scheduler is overloaded or if previous executions are still running."
|
||||
)
|
||||
|
||||
|
||||
def job_max_instances_listener(event):
|
||||
"""Logs when jobs hit max instances limit."""
|
||||
logger.warning(
|
||||
f"Job {event.job_id} execution was SKIPPED - max instances limit reached. "
|
||||
f"Previous execution(s) are still running. "
|
||||
f"Consider increasing max_instances or check why previous executions are taking too long."
|
||||
)
|
||||
|
||||
|
||||
_event_loop: asyncio.AbstractEventLoop | None = None
|
||||
_event_loop_thread: threading.Thread | None = None
|
||||
|
||||
|
||||
@func_retry
|
||||
def get_event_loop():
|
||||
return asyncio.new_event_loop()
|
||||
"""Get the shared event loop."""
|
||||
if _event_loop is None:
|
||||
raise RuntimeError("Event loop not initialized. Scheduler not started.")
|
||||
return _event_loop
|
||||
|
||||
|
||||
def run_async(coro, timeout: float = SCHEDULER_OPERATION_TIMEOUT_SECONDS):
|
||||
"""Run a coroutine in the shared event loop and wait for completion."""
|
||||
loop = get_event_loop()
|
||||
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
||||
try:
|
||||
return future.result(timeout=timeout)
|
||||
except Exception as e:
|
||||
logger.error(f"Async operation failed: {type(e).__name__}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def execute_graph(**kwargs):
|
||||
get_event_loop().run_until_complete(_execute_graph(**kwargs))
|
||||
"""Execute graph in the shared event loop and wait for completion."""
|
||||
# Wait for completion to ensure job doesn't exit prematurely
|
||||
run_async(_execute_graph(**kwargs))
|
||||
|
||||
|
||||
async def _execute_graph(**kwargs):
|
||||
args = GraphExecutionJobArgs(**kwargs)
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
try:
|
||||
logger.info(f"Executing recurring job for graph #{args.graph_id}")
|
||||
graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution(
|
||||
@@ -89,16 +144,28 @@ async def _execute_graph(**kwargs):
|
||||
inputs=args.input_data,
|
||||
graph_credentials_inputs=args.input_credentials,
|
||||
)
|
||||
elapsed = asyncio.get_event_loop().time() - start_time
|
||||
logger.info(
|
||||
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id}"
|
||||
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
|
||||
f"(took {elapsed:.2f}s to create and publish)"
|
||||
)
|
||||
if elapsed > 10:
|
||||
logger.warning(
|
||||
f"Graph execution {graph_exec.id} took {elapsed:.2f}s to create/publish - "
|
||||
f"this is unusually slow and may indicate resource contention"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing graph {args.graph_id}: {e}")
|
||||
elapsed = asyncio.get_event_loop().time() - start_time
|
||||
logger.error(
|
||||
f"Error executing graph {args.graph_id} after {elapsed:.2f}s: "
|
||||
f"{type(e).__name__}: {e}"
|
||||
)
|
||||
|
||||
|
||||
def cleanup_expired_files():
|
||||
"""Clean up expired files from cloud storage."""
|
||||
get_event_loop().run_until_complete(cleanup_expired_files_async())
|
||||
# Wait for completion
|
||||
run_async(cleanup_expired_files_async())
|
||||
|
||||
|
||||
# Monitoring functions are now imported from monitoring module
|
||||
@@ -154,7 +221,7 @@ class NotificationJobInfo(NotificationJobArgs):
|
||||
|
||||
|
||||
class Scheduler(AppService):
|
||||
scheduler: BlockingScheduler
|
||||
scheduler: BackgroundScheduler
|
||||
|
||||
def __init__(self, register_system_tasks: bool = True):
|
||||
self.register_system_tasks = register_system_tasks
|
||||
@@ -167,10 +234,50 @@ class Scheduler(AppService):
|
||||
def db_pool_size(cls) -> int:
|
||||
return config.scheduler_db_pool_size
|
||||
|
||||
async def health_check(self) -> str:
|
||||
# Thread-safe health check with proper initialization handling
|
||||
if not hasattr(self, "scheduler"):
|
||||
raise UnhealthyServiceError("Scheduler is still initializing")
|
||||
|
||||
# Check if we're in the middle of cleanup
|
||||
if self.cleaned_up:
|
||||
return await super().health_check()
|
||||
|
||||
# Normal operation - check if scheduler is running
|
||||
if not self.scheduler.running:
|
||||
raise UnhealthyServiceError("Scheduler is not running")
|
||||
|
||||
return await super().health_check()
|
||||
|
||||
def run_service(self):
|
||||
load_dotenv()
|
||||
|
||||
# Initialize the event loop for async jobs
|
||||
global _event_loop
|
||||
_event_loop = asyncio.new_event_loop()
|
||||
|
||||
# Use daemon thread since it should die with the main service
|
||||
global _event_loop_thread
|
||||
_event_loop_thread = threading.Thread(
|
||||
target=_event_loop.run_forever, daemon=True, name="SchedulerEventLoop"
|
||||
)
|
||||
_event_loop_thread.start()
|
||||
|
||||
db_schema, db_url = _extract_schema_from_url(os.getenv("DIRECT_URL"))
|
||||
self.scheduler = BlockingScheduler(
|
||||
# Configure executors to limit concurrency without skipping jobs
|
||||
from apscheduler.executors.pool import ThreadPoolExecutor
|
||||
|
||||
self.scheduler = BackgroundScheduler(
|
||||
executors={
|
||||
"default": ThreadPoolExecutor(
|
||||
max_workers=self.db_pool_size()
|
||||
), # Match DB pool size to prevent resource contention
|
||||
},
|
||||
job_defaults={
|
||||
"coalesce": True, # Skip redundant missed jobs - just run the latest
|
||||
"max_instances": 1000, # Effectively unlimited - never drop executions
|
||||
"misfire_grace_time": None, # No time limit for missed jobs
|
||||
},
|
||||
jobstores={
|
||||
Jobstores.EXECUTION.value: SQLAlchemyJobStore(
|
||||
engine=create_engine(
|
||||
@@ -200,9 +307,10 @@ class Scheduler(AppService):
|
||||
|
||||
if self.register_system_tasks:
|
||||
# Notification PROCESS WEEKLY SUMMARY
|
||||
# Runs every Monday at 9 AM UTC
|
||||
self.scheduler.add_job(
|
||||
process_weekly_summary,
|
||||
CronTrigger.from_crontab("0 * * * *"),
|
||||
CronTrigger.from_crontab("0 9 * * 1"),
|
||||
id="process_weekly_summary",
|
||||
kwargs={},
|
||||
replace_existing=True,
|
||||
@@ -250,13 +358,30 @@ class Scheduler(AppService):
|
||||
)
|
||||
|
||||
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
|
||||
self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
|
||||
self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
|
||||
self.scheduler.start()
|
||||
|
||||
# Keep the service running since BackgroundScheduler doesn't block
|
||||
super().run_service()
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
logger.info("⏳ Shutting down scheduler...")
|
||||
if self.scheduler:
|
||||
self.scheduler.shutdown(wait=False)
|
||||
logger.info("⏳ Shutting down scheduler...")
|
||||
self.scheduler.shutdown(wait=True)
|
||||
|
||||
global _event_loop
|
||||
if _event_loop:
|
||||
logger.info("⏳ Closing event loop...")
|
||||
_event_loop.call_soon_threadsafe(_event_loop.stop)
|
||||
|
||||
global _event_loop_thread
|
||||
if _event_loop_thread:
|
||||
logger.info("⏳ Waiting for event loop thread to finish...")
|
||||
_event_loop_thread.join(timeout=SCHEDULER_OPERATION_TIMEOUT_SECONDS)
|
||||
|
||||
logger.info("Scheduler cleanup complete.")
|
||||
|
||||
@expose
|
||||
def add_graph_execution_schedule(
|
||||
@@ -269,6 +394,18 @@ class Scheduler(AppService):
|
||||
input_credentials: dict[str, CredentialsMetaInput],
|
||||
name: Optional[str] = None,
|
||||
) -> GraphExecutionJobInfo:
|
||||
# Validate the graph before scheduling to prevent runtime failures
|
||||
# We don't need the return value, just want the validation to run
|
||||
run_async(
|
||||
execution_utils.validate_and_construct_node_execution_input(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
graph_inputs=input_data,
|
||||
graph_version=graph_version,
|
||||
graph_credentials_inputs=input_credentials,
|
||||
)
|
||||
)
|
||||
|
||||
job_args = GraphExecutionJobArgs(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import pytest
|
||||
|
||||
from backend.data import db
|
||||
from backend.executor.scheduler import SchedulerClient
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
@@ -17,7 +16,7 @@ async def test_agent_schedule(server: SpinTestServer):
|
||||
user_id=test_user.id,
|
||||
)
|
||||
|
||||
scheduler = get_service_client(SchedulerClient)
|
||||
scheduler = get_scheduler_client()
|
||||
schedules = await scheduler.get_execution_schedules(test_graph.id, test_user.id)
|
||||
assert len(schedules) == 0
|
||||
|
||||
|
||||
@@ -4,52 +4,36 @@ import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
|
||||
from typing import Any, Optional
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from pydantic import BaseModel, JsonValue
|
||||
from pydantic import BaseModel, JsonValue, ValidationError
|
||||
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockData,
|
||||
BlockInput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCostType
|
||||
from backend.data.db import prisma
|
||||
from backend.data.execution import (
|
||||
AsyncRedisExecutionEventBus,
|
||||
ExecutionStatus,
|
||||
GraphExecutionStats,
|
||||
GraphExecutionWithNodes,
|
||||
RedisExecutionEventBus,
|
||||
)
|
||||
from backend.data.graph import GraphModel, Node
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.rabbitmq import (
|
||||
AsyncRabbitMQ,
|
||||
Exchange,
|
||||
ExchangeType,
|
||||
Queue,
|
||||
RabbitMQConfig,
|
||||
SyncRabbitMQ,
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
get_async_execution_queue,
|
||||
get_database_manager_async_client,
|
||||
get_integration_credentials_store,
|
||||
)
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.exceptions import GraphValidationError, NotFoundError
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.mock import MockObject
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Config
|
||||
from backend.util.type import convert
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
config = Config()
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil]")
|
||||
|
||||
@@ -86,51 +70,6 @@ class LogMetadata(TruncatedLogger):
|
||||
)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_execution_event_bus() -> RedisExecutionEventBus:
|
||||
return RedisExecutionEventBus()
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_async_execution_event_bus() -> AsyncRedisExecutionEventBus:
|
||||
return AsyncRedisExecutionEventBus()
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_execution_queue() -> SyncRabbitMQ:
|
||||
client = SyncRabbitMQ(create_execution_queue_config())
|
||||
client.connect()
|
||||
return client
|
||||
|
||||
|
||||
@thread_cached
|
||||
async def get_async_execution_queue() -> AsyncRabbitMQ:
|
||||
client = AsyncRabbitMQ(create_execution_queue_config())
|
||||
await client.connect()
|
||||
return client
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_integration_credentials_store() -> "IntegrationCredentialsStore":
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
return IntegrationCredentialsStore()
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db_client() -> "DatabaseManagerClient":
|
||||
from backend.executor import DatabaseManagerClient
|
||||
|
||||
return get_service_client(DatabaseManagerClient)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db_async_client() -> "DatabaseManagerAsyncClient":
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
|
||||
return get_service_client(DatabaseManagerAsyncClient)
|
||||
|
||||
|
||||
# ============ Execution Cost Helpers ============ #
|
||||
|
||||
|
||||
@@ -457,7 +396,7 @@ def validate_exec(
|
||||
# Last validation: Validate the input values against the schema.
|
||||
if error := schema.get_mismatch_error(data):
|
||||
error_message = f"{error_prefix} {error}"
|
||||
logger.error(error_message)
|
||||
logger.warning(error_message)
|
||||
return None, error_message
|
||||
|
||||
return data, node_block.name
|
||||
@@ -467,47 +406,65 @@ async def _validate_node_input_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
):
|
||||
"""Checks all credentials for all nodes of the graph"""
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Checks all credentials for all nodes of the graph and returns structured errors.
|
||||
|
||||
Returns:
|
||||
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node
|
||||
"""
|
||||
credential_errors: dict[str, dict[str, str]] = defaultdict(dict)
|
||||
|
||||
for node in graph.nodes:
|
||||
block = node.block
|
||||
|
||||
# Find any fields of type CredentialsMetaInput
|
||||
credentials_fields = cast(
|
||||
type[BlockSchema], block.input_schema
|
||||
).get_credentials_fields()
|
||||
credentials_fields = block.input_schema.get_credentials_fields()
|
||||
if not credentials_fields:
|
||||
continue
|
||||
|
||||
for field_name, credentials_meta_type in credentials_fields.items():
|
||||
if (
|
||||
nodes_input_masks
|
||||
and (node_input_mask := nodes_input_masks.get(node.id))
|
||||
and field_name in node_input_mask
|
||||
):
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node_input_mask[field_name]
|
||||
)
|
||||
elif field_name in node.input_default:
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node.input_default[field_name]
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Credentials absent for {block.name} node #{node.id} "
|
||||
f"input '{field_name}'"
|
||||
)
|
||||
try:
|
||||
if (
|
||||
nodes_input_masks
|
||||
and (node_input_mask := nodes_input_masks.get(node.id))
|
||||
and field_name in node_input_mask
|
||||
):
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node_input_mask[field_name]
|
||||
)
|
||||
elif field_name in node.input_default:
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node.input_default[field_name]
|
||||
)
|
||||
else:
|
||||
# Missing credentials
|
||||
credential_errors[node.id][
|
||||
field_name
|
||||
] = "These credentials are required"
|
||||
continue
|
||||
except ValidationError as e:
|
||||
credential_errors[node.id][field_name] = f"Invalid credentials: {e}"
|
||||
continue
|
||||
|
||||
# Fetch the corresponding Credentials and perform sanity checks
|
||||
credentials = await get_integration_credentials_store().get_creds_by_id(
|
||||
user_id, credentials_meta.id
|
||||
)
|
||||
if not credentials:
|
||||
raise ValueError(
|
||||
f"Unknown credentials #{credentials_meta.id} "
|
||||
f"for node #{node.id} input '{field_name}'"
|
||||
try:
|
||||
# Fetch the corresponding Credentials and perform sanity checks
|
||||
credentials = await get_integration_credentials_store().get_creds_by_id(
|
||||
user_id, credentials_meta.id
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle any errors fetching credentials
|
||||
credential_errors[node.id][
|
||||
field_name
|
||||
] = f"Credentials not available: {e}"
|
||||
continue
|
||||
|
||||
if not credentials:
|
||||
credential_errors[node.id][
|
||||
field_name
|
||||
] = f"Unknown credentials #{credentials_meta.id}"
|
||||
continue
|
||||
|
||||
if (
|
||||
credentials.provider != credentials_meta.provider
|
||||
or credentials.type != credentials_meta.type
|
||||
@@ -518,10 +475,12 @@ async def _validate_node_input_credentials(
|
||||
f"{credentials_meta.type}<>{credentials.type};"
|
||||
f"{credentials_meta.provider}<>{credentials.provider}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Invalid credentials #{credentials.id} for node #{node.id}: "
|
||||
"type/provider mismatch"
|
||||
)
|
||||
credential_errors[node.id][
|
||||
field_name
|
||||
] = "Invalid credentials: type/provider mismatch"
|
||||
continue
|
||||
|
||||
return credential_errors
|
||||
|
||||
|
||||
def make_node_credentials_input_map(
|
||||
@@ -559,7 +518,37 @@ def make_node_credentials_input_map(
|
||||
return result
|
||||
|
||||
|
||||
async def construct_node_execution_input(
|
||||
async def validate_graph_with_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Validate graph including credentials and return structured errors per node.
|
||||
|
||||
Returns:
|
||||
dict[node_id, dict[field_name, error_message]]: Validation errors per node
|
||||
"""
|
||||
# Get input validation errors
|
||||
node_input_errors = GraphModel.validate_graph_get_errors(
|
||||
graph, for_run=True, nodes_input_masks=nodes_input_masks
|
||||
)
|
||||
|
||||
# Get credential input/availability/validation errors
|
||||
node_credential_input_errors = await _validate_node_input_credentials(
|
||||
graph, user_id, nodes_input_masks
|
||||
)
|
||||
|
||||
# Merge credential errors with structural errors
|
||||
for node_id, field_errors in node_credential_input_errors.items():
|
||||
if node_id not in node_input_errors:
|
||||
node_input_errors[node_id] = {}
|
||||
node_input_errors[node_id].update(field_errors)
|
||||
|
||||
return node_input_errors
|
||||
|
||||
|
||||
async def _construct_starting_node_execution_input(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
graph_inputs: BlockInput,
|
||||
@@ -581,8 +570,17 @@ async def construct_node_execution_input(
|
||||
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
|
||||
the corresponding input data for that node.
|
||||
"""
|
||||
graph.validate_graph(for_run=True, nodes_input_masks=nodes_input_masks)
|
||||
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
|
||||
# Use new validation function that includes credentials
|
||||
validation_errors = await validate_graph_with_credentials(
|
||||
graph, user_id, nodes_input_masks
|
||||
)
|
||||
n_error_nodes = len(validation_errors)
|
||||
n_errors = sum(len(errors) for errors in validation_errors.values())
|
||||
if validation_errors:
|
||||
raise GraphValidationError(
|
||||
f"Graph validation failed: {n_errors} issues on {n_error_nodes} nodes",
|
||||
node_errors=validation_errors,
|
||||
)
|
||||
|
||||
nodes_input = []
|
||||
for node in graph.starting_nodes:
|
||||
@@ -617,6 +615,67 @@ async def construct_node_execution_input(
|
||||
return nodes_input
|
||||
|
||||
|
||||
async def validate_and_construct_node_execution_input(
|
||||
graph_id: str,
|
||||
user_id: str,
|
||||
graph_inputs: BlockInput,
|
||||
graph_version: Optional[int] = None,
|
||||
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], dict[str, dict[str, JsonValue]]]:
|
||||
"""
|
||||
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
|
||||
This centralizes the logic used by both scheduler validation and actual execution.
|
||||
|
||||
Args:
|
||||
graph_id: The ID of the graph to validate/construct.
|
||||
user_id: The ID of the user.
|
||||
graph_inputs: The input data for the graph execution.
|
||||
graph_version: The version of the graph to use.
|
||||
graph_credentials_inputs: Credentials inputs to use.
|
||||
nodes_input_masks: Node inputs to use.
|
||||
|
||||
Returns:
|
||||
tuple[GraphModel, list[tuple[str, BlockInput]]]: Graph model and list of tuples for node execution input.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the graph is not found.
|
||||
GraphValidationError: If the graph has validation issues.
|
||||
ValueError: If there are other validation errors.
|
||||
"""
|
||||
if prisma.is_connected():
|
||||
gdb = graph_db
|
||||
else:
|
||||
gdb = get_database_manager_async_client()
|
||||
|
||||
graph: GraphModel | None = await gdb.get_graph(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
version=graph_version,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
if not graph:
|
||||
raise NotFoundError(f"Graph #{graph_id} not found.")
|
||||
|
||||
nodes_input_masks = _merge_nodes_input_masks(
|
||||
(
|
||||
make_node_credentials_input_map(graph, graph_credentials_inputs)
|
||||
if graph_credentials_inputs
|
||||
else {}
|
||||
),
|
||||
nodes_input_masks or {},
|
||||
)
|
||||
|
||||
starting_nodes_input = await _construct_starting_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=graph_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
|
||||
return graph, starting_nodes_input, nodes_input_masks
|
||||
|
||||
|
||||
def _merge_nodes_input_masks(
|
||||
overrides_map_1: dict[str, dict[str, JsonValue]],
|
||||
overrides_map_2: dict[str, dict[str, JsonValue]],
|
||||
@@ -633,11 +692,6 @@ def _merge_nodes_input_masks(
|
||||
|
||||
# ============ Execution Queue Helpers ============ #
|
||||
|
||||
|
||||
class CancelExecutionEvent(BaseModel):
|
||||
graph_exec_id: str
|
||||
|
||||
|
||||
GRAPH_EXECUTION_EXCHANGE = Exchange(
|
||||
name="graph_execution",
|
||||
type=ExchangeType.DIRECT,
|
||||
@@ -655,6 +709,11 @@ GRAPH_EXECUTION_CANCEL_EXCHANGE = Exchange(
|
||||
)
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME = "graph_execution_cancel_queue"
|
||||
|
||||
# Graceful shutdown timeout constants
|
||||
# Agent executions can run for up to 1 day, so we need a graceful shutdown period
|
||||
# that allows long-running executions to complete naturally
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = 24 * 60 * 60 # 1 day to complete active executions
|
||||
|
||||
|
||||
def create_execution_queue_config() -> RabbitMQConfig:
|
||||
"""
|
||||
@@ -669,13 +728,14 @@ def create_execution_queue_config() -> RabbitMQConfig:
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
arguments={
|
||||
# x-consumer-timeout (0 = disabled)
|
||||
# x-consumer-timeout (1 week)
|
||||
# Problem: Default 30-minute consumer timeout kills long-running graph executions
|
||||
# Original error: "Consumer acknowledgement timed out after 1800000 ms (30 minutes)"
|
||||
# Solution: Disable consumer timeout entirely - let graphs run indefinitely
|
||||
# Safety: Heartbeat mechanism now handles dead consumer detection instead
|
||||
# Use case: Graph executions that take hours to complete (AI model training, etc.)
|
||||
"x-consumer-timeout": 0,
|
||||
"x-consumer-timeout": GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS
|
||||
* 1000,
|
||||
},
|
||||
)
|
||||
cancel_queue = Queue(
|
||||
@@ -692,6 +752,10 @@ def create_execution_queue_config() -> RabbitMQConfig:
|
||||
)
|
||||
|
||||
|
||||
class CancelExecutionEvent(BaseModel):
|
||||
graph_exec_id: str
|
||||
|
||||
|
||||
async def stop_graph_execution(
|
||||
user_id: str,
|
||||
graph_exec_id: str,
|
||||
@@ -705,7 +769,7 @@ async def stop_graph_execution(
|
||||
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
|
||||
"""
|
||||
queue_client = await get_async_execution_queue()
|
||||
db = execution_db if prisma.is_connected() else get_db_async_client()
|
||||
db = execution_db if prisma.is_connected() else get_database_manager_async_client()
|
||||
await queue_client.publish_message(
|
||||
routing_key="",
|
||||
message=CancelExecutionEvent(graph_exec_id=graph_exec_id).model_dump_json(),
|
||||
@@ -737,51 +801,28 @@ async def stop_graph_execution(
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
]:
|
||||
break
|
||||
# If the graph is still on the queue, we can prevent them from being executed
|
||||
# by setting the status to TERMINATED.
|
||||
graph_exec.status = ExecutionStatus.TERMINATED
|
||||
|
||||
await asyncio.gather(
|
||||
# Update graph execution status
|
||||
db.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.id,
|
||||
status=ExecutionStatus.TERMINATED,
|
||||
),
|
||||
# Publish graph execution event
|
||||
get_async_execution_event_bus().publish(graph_exec),
|
||||
)
|
||||
return
|
||||
|
||||
if graph_exec.status == ExecutionStatus.RUNNING:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Set the termination status if the graph is not stopped after the timeout.
|
||||
if graph_exec := await db.get_graph_execution_meta(
|
||||
execution_id=graph_exec_id, user_id=user_id
|
||||
):
|
||||
# If the graph is still on the queue, we can prevent them from being executed
|
||||
# by setting the status to TERMINATED.
|
||||
node_execs = await db.get_node_executions(
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.RUNNING,
|
||||
],
|
||||
include_exec_data=False,
|
||||
)
|
||||
|
||||
graph_exec.status = ExecutionStatus.TERMINATED
|
||||
for node_exec in node_execs:
|
||||
node_exec.status = ExecutionStatus.TERMINATED
|
||||
|
||||
await asyncio.gather(
|
||||
# Update node execution statuses
|
||||
db.update_node_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in node_execs],
|
||||
ExecutionStatus.TERMINATED,
|
||||
),
|
||||
# Publish node execution events
|
||||
*[
|
||||
get_async_execution_event_bus().publish(node_exec)
|
||||
for node_exec in node_execs
|
||||
],
|
||||
)
|
||||
await asyncio.gather(
|
||||
# Update graph execution status
|
||||
db.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec_id,
|
||||
status=ExecutionStatus.TERMINATED,
|
||||
),
|
||||
# Publish graph execution event
|
||||
get_async_execution_event_bus().publish(graph_exec),
|
||||
)
|
||||
raise TimeoutError(
|
||||
f"Graph execution #{graph_exec_id} will need to take longer than {wait_timeout} seconds to stop. "
|
||||
f"You can check the status of the execution in the UI or try again later."
|
||||
)
|
||||
|
||||
|
||||
async def add_graph_execution(
|
||||
@@ -811,61 +852,62 @@ async def add_graph_execution(
|
||||
ValueError: If the graph is not found or if there are validation errors.
|
||||
"""
|
||||
if prisma.is_connected():
|
||||
gdb = graph_db
|
||||
edb = execution_db
|
||||
else:
|
||||
gdb = get_db_async_client()
|
||||
edb = get_db_async_client()
|
||||
edb = get_database_manager_async_client()
|
||||
|
||||
graph: GraphModel | None = await gdb.get_graph(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
version=graph_version,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
if not graph:
|
||||
raise NotFoundError(f"Graph #{graph_id} not found.")
|
||||
|
||||
nodes_input_masks = _merge_nodes_input_masks(
|
||||
(
|
||||
make_node_credentials_input_map(graph, graph_credentials_inputs)
|
||||
if graph_credentials_inputs
|
||||
else {}
|
||||
),
|
||||
nodes_input_masks or {},
|
||||
)
|
||||
starting_nodes_input = await construct_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=inputs or {},
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
|
||||
graph_exec = await edb.create_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
starting_nodes_input=starting_nodes_input,
|
||||
preset_id=preset_id,
|
||||
graph, starting_nodes_input, nodes_input_masks = (
|
||||
await validate_and_construct_node_execution_input(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
graph_inputs=inputs or {},
|
||||
graph_version=graph_version,
|
||||
graph_credentials_inputs=graph_credentials_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
)
|
||||
graph_exec = None
|
||||
|
||||
try:
|
||||
graph_exec = await edb.create_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
starting_nodes_input=starting_nodes_input,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
|
||||
queue = await get_async_execution_queue()
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry()
|
||||
if nodes_input_masks:
|
||||
graph_exec_entry.nodes_input_masks = nodes_input_masks
|
||||
|
||||
logger.info(
|
||||
f"Created graph execution #{graph_exec.id} for graph "
|
||||
f"#{graph_id} with {len(starting_nodes_input)} starting nodes. "
|
||||
f"Now publishing to execution queue."
|
||||
)
|
||||
|
||||
await queue.publish_message(
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
message=graph_exec_entry.model_dump_json(),
|
||||
exchange=GRAPH_EXECUTION_EXCHANGE,
|
||||
)
|
||||
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
|
||||
|
||||
bus = get_async_execution_event_bus()
|
||||
await bus.publish(graph_exec)
|
||||
|
||||
return graph_exec
|
||||
except Exception as e:
|
||||
logger.error(f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {e}")
|
||||
except BaseException as e:
|
||||
err = str(e) or type(e).__name__
|
||||
if not graph_exec:
|
||||
logger.error(f"Unable to execute graph #{graph_id} failed: {err}")
|
||||
raise
|
||||
|
||||
logger.error(
|
||||
f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {err}"
|
||||
)
|
||||
await edb.update_node_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in graph_exec.node_executions],
|
||||
ExecutionStatus.FAILED,
|
||||
@@ -873,7 +915,7 @@ async def add_graph_execution(
|
||||
await edb.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.id,
|
||||
status=ExecutionStatus.FAILED,
|
||||
stats=GraphExecutionStats(error=str(e)),
|
||||
stats=GraphExecutionStats(error=err),
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -888,13 +930,9 @@ class ExecutionOutputEntry(BaseModel):
|
||||
|
||||
|
||||
class NodeExecutionProgress:
|
||||
def __init__(
|
||||
self,
|
||||
on_done_task: Callable[[str, object], None],
|
||||
):
|
||||
def __init__(self):
|
||||
self.output: dict[str, list[ExecutionOutputEntry]] = defaultdict(list)
|
||||
self.tasks: dict[str, Future] = {}
|
||||
self.on_done_task = on_done_task
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def add_task(self, node_exec_id: str, task: Future):
|
||||
@@ -934,7 +972,9 @@ class NodeExecutionProgress:
|
||||
except TimeoutError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Task for exec ID {exec_id} failed with error: {str(e)}")
|
||||
logger.error(
|
||||
f"Task for exec ID {exec_id} failed with error: {e.__class__.__name__} {str(e)}"
|
||||
)
|
||||
pass
|
||||
return self.is_done(0)
|
||||
|
||||
@@ -952,7 +992,7 @@ class NodeExecutionProgress:
|
||||
cancelled_ids.append(task_id)
|
||||
return cancelled_ids
|
||||
|
||||
def wait_for_cancellation(self, timeout: float = 5.0):
|
||||
def wait_for_done(self, timeout: float = 5.0):
|
||||
"""
|
||||
Wait for all cancelled tasks to complete cancellation.
|
||||
|
||||
@@ -962,9 +1002,12 @@ class NodeExecutionProgress:
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
# Check if all tasks are done (either completed or cancelled)
|
||||
if all(task.done() for task in self.tasks.values()):
|
||||
return True
|
||||
while self.pop_output():
|
||||
pass
|
||||
|
||||
if self.is_done():
|
||||
return
|
||||
|
||||
time.sleep(0.1) # Small delay to avoid busy waiting
|
||||
|
||||
raise TimeoutError(
|
||||
@@ -983,11 +1026,7 @@ class NodeExecutionProgress:
|
||||
if self.output[exec_id]:
|
||||
return False
|
||||
|
||||
if task := self.tasks.pop(exec_id):
|
||||
try:
|
||||
self.on_done_task(exec_id, task.result())
|
||||
except Exception as e:
|
||||
logger.error(f"Task for exec ID {exec_id} failed with error: {str(e)}")
|
||||
self.tasks.pop(exec_id)
|
||||
return True
|
||||
|
||||
def _next_exec(self) -> str | None:
|
||||
|
||||
@@ -5,7 +5,6 @@ from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from autogpt_libs.utils.synchronize import AsyncRedisKeyedMutex
|
||||
from pydantic import SecretStr
|
||||
|
||||
@@ -183,6 +182,15 @@ zerobounce_credentials = APIKeyCredentials(
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
enrichlayer_credentials = APIKeyCredentials(
|
||||
id="d9fce73a-6c1d-4e8b-ba2e-12a456789def",
|
||||
provider="enrichlayer",
|
||||
api_key=SecretStr(settings.secrets.enrichlayer_api_key),
|
||||
title="Use Credits for Enrichlayer",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
|
||||
llama_api_credentials = APIKeyCredentials(
|
||||
id="d44045af-1c33-4833-9e19-752313214de2",
|
||||
provider="llama_api",
|
||||
@@ -191,6 +199,14 @@ llama_api_credentials = APIKeyCredentials(
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
v0_credentials = APIKeyCredentials(
|
||||
id="c4e6d1a0-3b5f-4789-a8e2-9b123456789f",
|
||||
provider="v0",
|
||||
api_key=SecretStr(settings.secrets.v0_api_key),
|
||||
title="Use Credits for v0 by Vercel",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
DEFAULT_CREDENTIALS = [
|
||||
ollama_credentials,
|
||||
revid_credentials,
|
||||
@@ -204,6 +220,7 @@ DEFAULT_CREDENTIALS = [
|
||||
jina_credentials,
|
||||
unreal_credentials,
|
||||
open_router_credentials,
|
||||
enrichlayer_credentials,
|
||||
fal_credentials,
|
||||
exa_credentials,
|
||||
e2b_credentials,
|
||||
@@ -214,6 +231,8 @@ DEFAULT_CREDENTIALS = [
|
||||
smartlead_credentials,
|
||||
zerobounce_credentials,
|
||||
google_maps_credentials,
|
||||
llama_api_credentials,
|
||||
v0_credentials,
|
||||
]
|
||||
|
||||
|
||||
@@ -229,17 +248,15 @@ class IntegrationCredentialsStore:
|
||||
return self._locks
|
||||
|
||||
@property
|
||||
@thread_cached
|
||||
def db_manager(self):
|
||||
if prisma.is_connected():
|
||||
from backend.data import user
|
||||
|
||||
return user
|
||||
else:
|
||||
from backend.executor.database import DatabaseManagerAsyncClient
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
return get_service_client(DatabaseManagerAsyncClient)
|
||||
return get_database_manager_async_client()
|
||||
|
||||
# =============== USER-MANAGED CREDENTIALS =============== #
|
||||
async def add_creds(self, user_id: str, credentials: Credentials) -> None:
|
||||
@@ -282,6 +299,8 @@ class IntegrationCredentialsStore:
|
||||
all_credentials.append(unreal_credentials)
|
||||
if settings.secrets.open_router_api_key:
|
||||
all_credentials.append(open_router_credentials)
|
||||
if settings.secrets.enrichlayer_api_key:
|
||||
all_credentials.append(enrichlayer_credentials)
|
||||
if settings.secrets.fal_api_key:
|
||||
all_credentials.append(fal_credentials)
|
||||
if settings.secrets.exa_api_key:
|
||||
@@ -363,21 +382,6 @@ class IntegrationCredentialsStore:
|
||||
|
||||
# ============== SYSTEM-MANAGED CREDENTIALS ============== #
|
||||
|
||||
async def get_ayrshare_profile_key(self, user_id: str) -> SecretStr | None:
|
||||
"""Get the Ayrshare profile key for a user.
|
||||
|
||||
The profile key is used to authenticate API requests to Ayrshare's social media posting service.
|
||||
See https://www.ayrshare.com/docs/apis/profiles/overview for more details.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user to get the profile key for
|
||||
|
||||
Returns:
|
||||
The profile key as a SecretStr if set, None otherwise
|
||||
"""
|
||||
user_integrations = await self._get_user_integrations(user_id)
|
||||
return user_integrations.managed_credentials.ayrshare_profile_key
|
||||
|
||||
async def set_ayrshare_profile_key(self, user_id: str, profile_key: str) -> None:
|
||||
"""Set the Ayrshare profile key for a user.
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ class ProviderName(str, Enum):
|
||||
GROQ = "groq"
|
||||
HTTP = "http"
|
||||
HUBSPOT = "hubspot"
|
||||
ENRICHLAYER = "enrichlayer"
|
||||
IDEOGRAM = "ideogram"
|
||||
JINA = "jina"
|
||||
LLAMA_API = "llama_api"
|
||||
@@ -47,6 +48,7 @@ class ProviderName(str, Enum):
|
||||
TWITTER = "twitter"
|
||||
TODOIST = "todoist"
|
||||
UNREAL_SPEECH = "unreal_speech"
|
||||
V0 = "v0"
|
||||
ZEROBOUNCE = "zerobounce"
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -8,10 +8,11 @@ from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import get_block
|
||||
from backend.data.execution import ExecutionStatus, NodeExecutionResult
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
from backend.util.clients import (
|
||||
get_database_manager_client,
|
||||
get_notification_manager_client,
|
||||
)
|
||||
from backend.util.metrics import sentry_capture_error
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -40,7 +41,7 @@ class BlockErrorMonitor:
|
||||
|
||||
def __init__(self, include_top_blocks: int | None = None):
|
||||
self.config = config
|
||||
self.notification_client = get_service_client(NotificationManagerClient)
|
||||
self.notification_client = get_notification_manager_client()
|
||||
self.include_top_blocks = (
|
||||
include_top_blocks
|
||||
if include_top_blocks is not None
|
||||
@@ -107,7 +108,7 @@ class BlockErrorMonitor:
|
||||
) -> dict[str, BlockStatsWithSamples]:
|
||||
"""Get block execution stats using efficient SQL aggregation."""
|
||||
|
||||
result = execution_utils.get_db_client().get_block_error_stats(
|
||||
result = get_database_manager_client().get_block_error_stats(
|
||||
start_time, end_time
|
||||
)
|
||||
|
||||
@@ -197,7 +198,7 @@ class BlockErrorMonitor:
|
||||
) -> list[str]:
|
||||
"""Get error samples for a specific block - just a few recent ones."""
|
||||
# Only fetch a small number of recent failed executions for this specific block
|
||||
executions = execution_utils.get_db_client().get_node_executions(
|
||||
executions = get_database_manager_client().get_node_executions(
|
||||
block_ids=[block_id],
|
||||
statuses=[ExecutionStatus.FAILED],
|
||||
created_time_gte=start_time,
|
||||
|
||||
@@ -4,10 +4,11 @@ import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
from backend.util.clients import (
|
||||
get_database_manager_client,
|
||||
get_notification_manager_client,
|
||||
)
|
||||
from backend.util.metrics import sentry_capture_error
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -25,13 +26,13 @@ class LateExecutionMonitor:
|
||||
|
||||
def __init__(self):
|
||||
self.config = config
|
||||
self.notification_client = get_service_client(NotificationManagerClient)
|
||||
self.notification_client = get_notification_manager_client()
|
||||
|
||||
def check_late_executions(self) -> str:
|
||||
"""Check for late executions and send alerts if found."""
|
||||
|
||||
# Check for QUEUED executions
|
||||
queued_late_executions = execution_utils.get_db_client().get_graph_executions(
|
||||
queued_late_executions = get_database_manager_client().get_graph_executions(
|
||||
statuses=[ExecutionStatus.QUEUED],
|
||||
created_time_gte=datetime.now(timezone.utc)
|
||||
- timedelta(
|
||||
@@ -43,7 +44,7 @@ class LateExecutionMonitor:
|
||||
)
|
||||
|
||||
# Check for RUNNING executions stuck for more than 24 hours
|
||||
running_late_executions = execution_utils.get_db_client().get_graph_executions(
|
||||
running_late_executions = get_database_manager_client().get_graph_executions(
|
||||
statuses=[ExecutionStatus.RUNNING],
|
||||
created_time_gte=datetime.now(timezone.utc)
|
||||
- timedelta(hours=24)
|
||||
|
||||
@@ -2,12 +2,10 @@
|
||||
|
||||
import logging
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from prisma.enums import NotificationType
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.clients import get_notification_manager_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -17,11 +15,6 @@ class NotificationJobArgs(BaseModel):
|
||||
cron: str
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_notification_manager_client():
|
||||
return get_service_client(NotificationManagerClient)
|
||||
|
||||
|
||||
def process_existing_batches(**kwargs):
|
||||
"""Process existing notification batches."""
|
||||
args = NotificationJobArgs(**kwargs)
|
||||
|
||||
15
autogpt_platform/backend/backend/notification.py
Normal file
15
autogpt_platform/backend/backend/notification.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from backend.app import run_processes
|
||||
from backend.notifications.notifications import NotificationManager
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Run the AutoGPT-server Notification Service.
|
||||
"""
|
||||
run_processes(
|
||||
NotificationManager(),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,11 +1,9 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Callable
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
import aio_pika
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from prisma.enums import NotificationType
|
||||
|
||||
from backend.data import rabbitmq
|
||||
@@ -26,25 +24,19 @@ from backend.data.notifications import (
|
||||
get_notif_data_type,
|
||||
get_summary_params_type,
|
||||
)
|
||||
from backend.data.rabbitmq import (
|
||||
AsyncRabbitMQ,
|
||||
Exchange,
|
||||
ExchangeType,
|
||||
Queue,
|
||||
RabbitMQConfig,
|
||||
SyncRabbitMQ,
|
||||
)
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.data.user import generate_unsubscribe_link
|
||||
from backend.notifications.email import EmailSender
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.metrics import discord_send_alert
|
||||
from backend.util.retry import continuous_retry
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
UnhealthyServiceError,
|
||||
endpoint_to_sync,
|
||||
expose,
|
||||
get_service_client,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
@@ -56,8 +48,6 @@ NOTIFICATION_EXCHANGE = Exchange(name="notifications", type=ExchangeType.TOPIC)
|
||||
DEAD_LETTER_EXCHANGE = Exchange(name="dead_letter", type=ExchangeType.TOPIC)
|
||||
EXCHANGES = [NOTIFICATION_EXCHANGE, DEAD_LETTER_EXCHANGE]
|
||||
|
||||
background_executor = ProcessPoolExecutor(max_workers=2)
|
||||
|
||||
|
||||
def create_notification_config() -> RabbitMQConfig:
|
||||
"""Create RabbitMQ configuration for notifications"""
|
||||
@@ -116,27 +106,6 @@ def create_notification_config() -> RabbitMQConfig:
|
||||
)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db():
|
||||
from backend.executor.database import DatabaseManagerClient
|
||||
|
||||
return get_service_client(DatabaseManagerClient)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_notification_queue() -> SyncRabbitMQ:
|
||||
client = SyncRabbitMQ(create_notification_config())
|
||||
client.connect()
|
||||
return client
|
||||
|
||||
|
||||
@thread_cached
|
||||
async def get_async_notification_queue() -> AsyncRabbitMQ:
|
||||
client = AsyncRabbitMQ(create_notification_config())
|
||||
await client.connect()
|
||||
return client
|
||||
|
||||
|
||||
def get_routing_key(event_type: NotificationType) -> str:
|
||||
strategy = NotificationTypeOverride(event_type).strategy
|
||||
"""Get the appropriate routing key for an event"""
|
||||
@@ -161,6 +130,8 @@ def queue_notification(event: NotificationEventModel) -> NotificationResult:
|
||||
exchange = "notifications"
|
||||
routing_key = get_routing_key(event.type)
|
||||
|
||||
from backend.util.clients import get_notification_queue
|
||||
|
||||
queue = get_notification_queue()
|
||||
queue.publish_message(
|
||||
routing_key=routing_key,
|
||||
@@ -186,6 +157,8 @@ async def queue_notification_async(event: NotificationEventModel) -> Notificatio
|
||||
exchange = "notifications"
|
||||
routing_key = get_routing_key(event.type)
|
||||
|
||||
from backend.util.clients import get_async_notification_queue
|
||||
|
||||
queue = await get_async_notification_queue()
|
||||
await queue.publish_message(
|
||||
routing_key=routing_key,
|
||||
@@ -215,24 +188,33 @@ class NotificationManager(AppService):
|
||||
@property
|
||||
def rabbit(self) -> rabbitmq.AsyncRabbitMQ:
|
||||
"""Access the RabbitMQ service. Will raise if not configured."""
|
||||
if not self.rabbitmq_service:
|
||||
raise RuntimeError("RabbitMQ not configured for this service")
|
||||
if not hasattr(self, "rabbitmq_service") or not self.rabbitmq_service:
|
||||
raise UnhealthyServiceError("RabbitMQ not configured for this service")
|
||||
return self.rabbitmq_service
|
||||
|
||||
@property
|
||||
def rabbit_config(self) -> rabbitmq.RabbitMQConfig:
|
||||
"""Access the RabbitMQ config. Will raise if not configured."""
|
||||
if not self.rabbitmq_config:
|
||||
raise RuntimeError("RabbitMQ not configured for this service")
|
||||
raise UnhealthyServiceError("RabbitMQ not configured for this service")
|
||||
return self.rabbitmq_config
|
||||
|
||||
async def health_check(self) -> str:
|
||||
# Service is unhealthy if RabbitMQ is not ready
|
||||
if not hasattr(self, "rabbitmq_service") or not self.rabbitmq_service:
|
||||
raise UnhealthyServiceError("RabbitMQ not configured for this service")
|
||||
if not self.rabbitmq_service.is_ready:
|
||||
raise UnhealthyServiceError("RabbitMQ channel is not ready")
|
||||
return await super().health_check()
|
||||
|
||||
@classmethod
|
||||
def get_port(cls) -> int:
|
||||
return settings.config.notification_service_port
|
||||
|
||||
@expose
|
||||
def queue_weekly_summary(self):
|
||||
background_executor.submit(lambda: asyncio.run(self._queue_weekly_summary()))
|
||||
async def queue_weekly_summary(self):
|
||||
# Use the existing event loop instead of creating a new one with asyncio.run()
|
||||
asyncio.create_task(self._queue_weekly_summary())
|
||||
|
||||
async def _queue_weekly_summary(self):
|
||||
"""Process weekly summary for specified notification types"""
|
||||
@@ -241,10 +223,14 @@ class NotificationManager(AppService):
|
||||
processed_count = 0
|
||||
current_time = datetime.now(tz=timezone.utc)
|
||||
start_time = current_time - timedelta(days=7)
|
||||
users = get_db().get_active_user_ids_in_timerange(
|
||||
logger.info(
|
||||
f"Querying for active users between {start_time} and {current_time}"
|
||||
)
|
||||
users = await get_database_manager_async_client().get_active_user_ids_in_timerange(
|
||||
end_time=current_time.isoformat(),
|
||||
start_time=start_time.isoformat(),
|
||||
)
|
||||
logger.info(f"Found {len(users)} active users in the last 7 days")
|
||||
for user in users:
|
||||
await self._queue_scheduled_notification(
|
||||
SummaryParamsEventModel(
|
||||
@@ -264,10 +250,15 @@ class NotificationManager(AppService):
|
||||
logger.exception(f"Error processing weekly summary: {e}")
|
||||
|
||||
@expose
|
||||
def process_existing_batches(self, notification_types: list[NotificationType]):
|
||||
background_executor.submit(self._process_existing_batches, notification_types)
|
||||
async def process_existing_batches(
|
||||
self, notification_types: list[NotificationType]
|
||||
):
|
||||
# Use the existing event loop instead of creating a new process
|
||||
asyncio.create_task(self._process_existing_batches(notification_types))
|
||||
|
||||
def _process_existing_batches(self, notification_types: list[NotificationType]):
|
||||
async def _process_existing_batches(
|
||||
self, notification_types: list[NotificationType]
|
||||
):
|
||||
"""Process existing batches for specified notification types"""
|
||||
try:
|
||||
processed_count = 0
|
||||
@@ -275,14 +266,16 @@ class NotificationManager(AppService):
|
||||
|
||||
for notification_type in notification_types:
|
||||
# Get all batches for this notification type
|
||||
batches = get_db().get_all_batches_by_type(notification_type)
|
||||
batches = (
|
||||
await get_database_manager_async_client().get_all_batches_by_type(
|
||||
notification_type
|
||||
)
|
||||
)
|
||||
|
||||
for batch in batches:
|
||||
# Check if batch has aged out
|
||||
oldest_message = (
|
||||
get_db().get_user_notification_oldest_message_in_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
oldest_message = await get_database_manager_async_client().get_user_notification_oldest_message_in_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
|
||||
if not oldest_message:
|
||||
@@ -296,7 +289,9 @@ class NotificationManager(AppService):
|
||||
|
||||
# If batch has aged out, process it
|
||||
if oldest_message.created_at + max_delay < current_time:
|
||||
recipient_email = get_db().get_user_email_by_id(batch.user_id)
|
||||
recipient_email = await get_database_manager_async_client().get_user_email_by_id(
|
||||
batch.user_id
|
||||
)
|
||||
|
||||
if not recipient_email:
|
||||
logger.error(
|
||||
@@ -304,7 +299,7 @@ class NotificationManager(AppService):
|
||||
)
|
||||
continue
|
||||
|
||||
should_send = self._should_email_user_based_on_preference(
|
||||
should_send = await self._should_email_user_based_on_preference(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
|
||||
@@ -313,12 +308,12 @@ class NotificationManager(AppService):
|
||||
f"User {batch.user_id} does not want to receive {notification_type} notifications"
|
||||
)
|
||||
# Clear the batch
|
||||
get_db().empty_user_notification_batch(
|
||||
await get_database_manager_async_client().empty_user_notification_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
continue
|
||||
|
||||
batch_data = get_db().get_user_notification_batch(
|
||||
batch_data = await get_database_manager_async_client().get_user_notification_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
|
||||
@@ -327,7 +322,7 @@ class NotificationManager(AppService):
|
||||
f"Batch data not found for user {batch.user_id}"
|
||||
)
|
||||
# Clear the batch
|
||||
get_db().empty_user_notification_batch(
|
||||
await get_database_manager_async_client().empty_user_notification_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
continue
|
||||
@@ -363,7 +358,7 @@ class NotificationManager(AppService):
|
||||
)
|
||||
|
||||
# Clear the batch
|
||||
get_db().empty_user_notification_batch(
|
||||
await get_database_manager_async_client().empty_user_notification_batch(
|
||||
batch.user_id, notification_type
|
||||
)
|
||||
|
||||
@@ -393,10 +388,13 @@ class NotificationManager(AppService):
|
||||
async def _queue_scheduled_notification(self, event: SummaryParamsEventModel):
|
||||
"""Queue a scheduled notification - exposed method for other services to call"""
|
||||
try:
|
||||
logger.debug(f"Received Request to queue scheduled notification {event=}")
|
||||
logger.info(
|
||||
f"Queueing scheduled notification type={event.type} user_id={event.user_id}"
|
||||
)
|
||||
|
||||
exchange = "notifications"
|
||||
routing_key = get_routing_key(event.type)
|
||||
logger.info(f"Using routing key: {routing_key}")
|
||||
|
||||
# Publish to RabbitMQ
|
||||
await self.rabbit.publish_message(
|
||||
@@ -404,110 +402,131 @@ class NotificationManager(AppService):
|
||||
message=event.model_dump_json(),
|
||||
exchange=next(ex for ex in EXCHANGES if ex.name == exchange),
|
||||
)
|
||||
logger.info(f"Successfully queued notification for user {event.user_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error queueing notification: {e}")
|
||||
|
||||
def _should_email_user_based_on_preference(
|
||||
async def _should_email_user_based_on_preference(
|
||||
self, user_id: str, event_type: NotificationType
|
||||
) -> bool:
|
||||
"""Check if a user wants to receive a notification based on their preferences and email verification status"""
|
||||
validated_email = get_db().get_user_email_verification(user_id)
|
||||
preference = (
|
||||
get_db()
|
||||
.get_user_notification_preference(user_id)
|
||||
.preferences.get(event_type, True)
|
||||
validated_email = (
|
||||
await get_database_manager_async_client().get_user_email_verification(
|
||||
user_id
|
||||
)
|
||||
)
|
||||
preference = (
|
||||
await get_database_manager_async_client().get_user_notification_preference(
|
||||
user_id
|
||||
)
|
||||
).preferences.get(event_type, True)
|
||||
# only if both are true, should we email this person
|
||||
return validated_email and preference
|
||||
|
||||
def _gather_summary_data(
|
||||
async def _gather_summary_data(
|
||||
self, user_id: str, event_type: NotificationType, params: BaseSummaryParams
|
||||
) -> BaseSummaryData:
|
||||
"""Gathers the data to build a summary notification"""
|
||||
|
||||
logger.info(
|
||||
f"Gathering summary data for {user_id} and {event_type} wiht {params=}"
|
||||
f"Gathering summary data for {user_id} and {event_type} with {params=}"
|
||||
)
|
||||
|
||||
# total_credits_used = self.run_and_wait(
|
||||
# get_total_credits_used(user_id, start_time, end_time)
|
||||
# )
|
||||
|
||||
# total_executions = self.run_and_wait(
|
||||
# get_total_executions(user_id, start_time, end_time)
|
||||
# )
|
||||
|
||||
# most_used_agent = self.run_and_wait(
|
||||
# get_most_used_agent(user_id, start_time, end_time)
|
||||
# )
|
||||
|
||||
# execution_times = self.run_and_wait(
|
||||
# get_execution_time(user_id, start_time, end_time)
|
||||
# )
|
||||
|
||||
# runs = self.run_and_wait(
|
||||
# get_runs(user_id, start_time, end_time)
|
||||
# )
|
||||
total_credits_used = 3.0
|
||||
total_executions = 2
|
||||
most_used_agent = {"name": "Some"}
|
||||
execution_times = [1, 2, 3]
|
||||
runs = [{"status": "COMPLETED"}, {"status": "FAILED"}]
|
||||
|
||||
successful_runs = len([run for run in runs if run["status"] == "COMPLETED"])
|
||||
failed_runs = len([run for run in runs if run["status"] != "COMPLETED"])
|
||||
average_execution_time = (
|
||||
sum(execution_times) / len(execution_times) if execution_times else 0
|
||||
)
|
||||
# cost_breakdown = self.run_and_wait(
|
||||
# get_cost_breakdown(user_id, start_time, end_time)
|
||||
# )
|
||||
|
||||
cost_breakdown = {
|
||||
"agent1": 1.0,
|
||||
"agent2": 2.0,
|
||||
}
|
||||
|
||||
if event_type == NotificationType.DAILY_SUMMARY and isinstance(
|
||||
params, DailySummaryParams
|
||||
):
|
||||
return DailySummaryData(
|
||||
total_credits_used=total_credits_used,
|
||||
total_executions=total_executions,
|
||||
most_used_agent=most_used_agent["name"],
|
||||
total_execution_time=sum(execution_times),
|
||||
successful_runs=successful_runs,
|
||||
failed_runs=failed_runs,
|
||||
average_execution_time=average_execution_time,
|
||||
cost_breakdown=cost_breakdown,
|
||||
date=params.date,
|
||||
try:
|
||||
# Get summary data from the database
|
||||
summary_data = await get_database_manager_async_client().get_user_execution_summary_data(
|
||||
user_id=user_id,
|
||||
start_time=params.start_date,
|
||||
end_time=params.end_date,
|
||||
)
|
||||
elif event_type == NotificationType.WEEKLY_SUMMARY and isinstance(
|
||||
params, WeeklySummaryParams
|
||||
):
|
||||
return WeeklySummaryData(
|
||||
total_credits_used=total_credits_used,
|
||||
total_executions=total_executions,
|
||||
most_used_agent=most_used_agent["name"],
|
||||
total_execution_time=sum(execution_times),
|
||||
successful_runs=successful_runs,
|
||||
failed_runs=failed_runs,
|
||||
average_execution_time=average_execution_time,
|
||||
cost_breakdown=cost_breakdown,
|
||||
start_date=params.start_date,
|
||||
end_date=params.end_date,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid event type or params")
|
||||
|
||||
def _should_batch(
|
||||
# Extract data from summary
|
||||
total_credits_used = summary_data.total_credits_used
|
||||
total_executions = summary_data.total_executions
|
||||
most_used_agent = summary_data.most_used_agent
|
||||
successful_runs = summary_data.successful_runs
|
||||
failed_runs = summary_data.failed_runs
|
||||
total_execution_time = summary_data.total_execution_time
|
||||
average_execution_time = summary_data.average_execution_time
|
||||
cost_breakdown = summary_data.cost_breakdown
|
||||
|
||||
if event_type == NotificationType.DAILY_SUMMARY and isinstance(
|
||||
params, DailySummaryParams
|
||||
):
|
||||
return DailySummaryData(
|
||||
total_credits_used=total_credits_used,
|
||||
total_executions=total_executions,
|
||||
most_used_agent=most_used_agent,
|
||||
total_execution_time=total_execution_time,
|
||||
successful_runs=successful_runs,
|
||||
failed_runs=failed_runs,
|
||||
average_execution_time=average_execution_time,
|
||||
cost_breakdown=cost_breakdown,
|
||||
date=params.date,
|
||||
)
|
||||
elif event_type == NotificationType.WEEKLY_SUMMARY and isinstance(
|
||||
params, WeeklySummaryParams
|
||||
):
|
||||
return WeeklySummaryData(
|
||||
total_credits_used=total_credits_used,
|
||||
total_executions=total_executions,
|
||||
most_used_agent=most_used_agent,
|
||||
total_execution_time=total_execution_time,
|
||||
successful_runs=successful_runs,
|
||||
failed_runs=failed_runs,
|
||||
average_execution_time=average_execution_time,
|
||||
cost_breakdown=cost_breakdown,
|
||||
start_date=params.start_date,
|
||||
end_date=params.end_date,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid event type or params")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to gather summary data: {e}")
|
||||
# Return sensible defaults in case of error
|
||||
if event_type == NotificationType.DAILY_SUMMARY and isinstance(
|
||||
params, DailySummaryParams
|
||||
):
|
||||
return DailySummaryData(
|
||||
total_credits_used=0.0,
|
||||
total_executions=0,
|
||||
most_used_agent="No data available",
|
||||
total_execution_time=0.0,
|
||||
successful_runs=0,
|
||||
failed_runs=0,
|
||||
average_execution_time=0.0,
|
||||
cost_breakdown={},
|
||||
date=params.date,
|
||||
)
|
||||
elif event_type == NotificationType.WEEKLY_SUMMARY and isinstance(
|
||||
params, WeeklySummaryParams
|
||||
):
|
||||
return WeeklySummaryData(
|
||||
total_credits_used=0.0,
|
||||
total_executions=0,
|
||||
most_used_agent="No data available",
|
||||
total_execution_time=0.0,
|
||||
successful_runs=0,
|
||||
failed_runs=0,
|
||||
average_execution_time=0.0,
|
||||
cost_breakdown={},
|
||||
start_date=params.start_date,
|
||||
end_date=params.end_date,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid event type or params") from e
|
||||
|
||||
async def _should_batch(
|
||||
self, user_id: str, event_type: NotificationType, event: NotificationEventModel
|
||||
) -> bool:
|
||||
|
||||
get_db().create_or_add_to_user_notification_batch(user_id, event_type, event)
|
||||
await get_database_manager_async_client().create_or_add_to_user_notification_batch(
|
||||
user_id, event_type, event
|
||||
)
|
||||
|
||||
oldest_message = get_db().get_user_notification_oldest_message_in_batch(
|
||||
oldest_message = await get_database_manager_async_client().get_user_notification_oldest_message_in_batch(
|
||||
user_id, event_type
|
||||
)
|
||||
if not oldest_message:
|
||||
@@ -537,7 +556,7 @@ class NotificationManager(AppService):
|
||||
logger.error(f"Error parsing message due to non matching schema {e}")
|
||||
return None
|
||||
|
||||
def _process_admin_message(self, message: str) -> bool:
|
||||
async def _process_admin_message(self, message: str) -> bool:
|
||||
"""Process a single notification, sending to an admin, returning whether to put into the failed queue"""
|
||||
try:
|
||||
event = self._parse_message(message)
|
||||
@@ -551,7 +570,7 @@ class NotificationManager(AppService):
|
||||
logger.exception(f"Error processing notification for admin queue: {e}")
|
||||
return False
|
||||
|
||||
def _process_immediate(self, message: str) -> bool:
|
||||
async def _process_immediate(self, message: str) -> bool:
|
||||
"""Process a single notification immediately, returning whether to put into the failed queue"""
|
||||
try:
|
||||
event = self._parse_message(message)
|
||||
@@ -559,12 +578,16 @@ class NotificationManager(AppService):
|
||||
return False
|
||||
logger.debug(f"Processing immediate notification: {event}")
|
||||
|
||||
recipient_email = get_db().get_user_email_by_id(event.user_id)
|
||||
recipient_email = (
|
||||
await get_database_manager_async_client().get_user_email_by_id(
|
||||
event.user_id
|
||||
)
|
||||
)
|
||||
if not recipient_email:
|
||||
logger.error(f"User email not found for user {event.user_id}")
|
||||
return False
|
||||
|
||||
should_send = self._should_email_user_based_on_preference(
|
||||
should_send = await self._should_email_user_based_on_preference(
|
||||
event.user_id, event.type
|
||||
)
|
||||
if not should_send:
|
||||
@@ -586,7 +609,7 @@ class NotificationManager(AppService):
|
||||
logger.exception(f"Error processing notification for immediate queue: {e}")
|
||||
return False
|
||||
|
||||
def _process_batch(self, message: str) -> bool:
|
||||
async def _process_batch(self, message: str) -> bool:
|
||||
"""Process a single notification with a batching strategy, returning whether to put into the failed queue"""
|
||||
try:
|
||||
event = self._parse_message(message)
|
||||
@@ -594,12 +617,16 @@ class NotificationManager(AppService):
|
||||
return False
|
||||
logger.info(f"Processing batch notification: {event}")
|
||||
|
||||
recipient_email = get_db().get_user_email_by_id(event.user_id)
|
||||
recipient_email = (
|
||||
await get_database_manager_async_client().get_user_email_by_id(
|
||||
event.user_id
|
||||
)
|
||||
)
|
||||
if not recipient_email:
|
||||
logger.error(f"User email not found for user {event.user_id}")
|
||||
return False
|
||||
|
||||
should_send = self._should_email_user_based_on_preference(
|
||||
should_send = await self._should_email_user_based_on_preference(
|
||||
event.user_id, event.type
|
||||
)
|
||||
if not should_send:
|
||||
@@ -608,12 +635,16 @@ class NotificationManager(AppService):
|
||||
)
|
||||
return True
|
||||
|
||||
should_send = self._should_batch(event.user_id, event.type, event)
|
||||
should_send = await self._should_batch(event.user_id, event.type, event)
|
||||
|
||||
if not should_send:
|
||||
logger.info("Batch not old enough to send")
|
||||
return False
|
||||
batch = get_db().get_user_notification_batch(event.user_id, event.type)
|
||||
batch = (
|
||||
await get_database_manager_async_client().get_user_notification_batch(
|
||||
event.user_id, event.type
|
||||
)
|
||||
)
|
||||
if not batch or not batch.notifications:
|
||||
logger.error(f"Batch not found for user {event.user_id}")
|
||||
return False
|
||||
@@ -714,7 +745,9 @@ class NotificationManager(AppService):
|
||||
logger.info(
|
||||
f"Successfully sent all {successfully_sent_count} notifications, clearing batch"
|
||||
)
|
||||
get_db().empty_user_notification_batch(event.user_id, event.type)
|
||||
await get_database_manager_async_client().empty_user_notification_batch(
|
||||
event.user_id, event.type
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Only sent {successfully_sent_count} of {len(batch_messages)} notifications. "
|
||||
@@ -725,7 +758,7 @@ class NotificationManager(AppService):
|
||||
logger.exception(f"Error processing notification for batch queue: {e}")
|
||||
return False
|
||||
|
||||
def _process_summary(self, message: str) -> bool:
|
||||
async def _process_summary(self, message: str) -> bool:
|
||||
"""Process a single notification with a summary strategy, returning whether to put into the failed queue"""
|
||||
try:
|
||||
logger.info(f"Processing summary notification: {message}")
|
||||
@@ -736,11 +769,15 @@ class NotificationManager(AppService):
|
||||
|
||||
logger.info(f"Processing summary notification: {model}")
|
||||
|
||||
recipient_email = get_db().get_user_email_by_id(event.user_id)
|
||||
recipient_email = (
|
||||
await get_database_manager_async_client().get_user_email_by_id(
|
||||
event.user_id
|
||||
)
|
||||
)
|
||||
if not recipient_email:
|
||||
logger.error(f"User email not found for user {event.user_id}")
|
||||
return False
|
||||
should_send = self._should_email_user_based_on_preference(
|
||||
should_send = await self._should_email_user_based_on_preference(
|
||||
event.user_id, event.type
|
||||
)
|
||||
if not should_send:
|
||||
@@ -749,7 +786,7 @@ class NotificationManager(AppService):
|
||||
)
|
||||
return True
|
||||
|
||||
summary_data = self._gather_summary_data(
|
||||
summary_data = await self._gather_summary_data(
|
||||
event.user_id, event.type, model.data
|
||||
)
|
||||
|
||||
@@ -775,7 +812,7 @@ class NotificationManager(AppService):
|
||||
async def _consume_queue(
|
||||
self,
|
||||
queue: aio_pika.abc.AbstractQueue,
|
||||
process_func: Callable[[str], bool],
|
||||
process_func: Callable[[str], Awaitable[bool]],
|
||||
queue_name: str,
|
||||
):
|
||||
"""Continuously consume messages from a queue using async iteration"""
|
||||
@@ -789,7 +826,7 @@ class NotificationManager(AppService):
|
||||
|
||||
try:
|
||||
async with message.process():
|
||||
result = process_func(message.body.decode())
|
||||
result = await process_func(message.body.decode())
|
||||
if not result:
|
||||
# Message will be rejected when exiting context without exception
|
||||
raise aio_pika.exceptions.MessageProcessError(
|
||||
@@ -888,6 +925,8 @@ class NotificationManagerClient(AppServiceClient):
|
||||
def get_service_type(cls):
|
||||
return NotificationManager
|
||||
|
||||
process_existing_batches = NotificationManager.process_existing_batches
|
||||
queue_weekly_summary = NotificationManager.queue_weekly_summary
|
||||
process_existing_batches = endpoint_to_sync(
|
||||
NotificationManager.process_existing_batches
|
||||
)
|
||||
queue_weekly_summary = endpoint_to_sync(NotificationManager.queue_weekly_summary)
|
||||
discord_system_alert = endpoint_to_sync(NotificationManager.discord_system_alert)
|
||||
|
||||
@@ -5,23 +5,64 @@ data.start_date: the start date of the summary
|
||||
data.end_date: the end date of the summary
|
||||
data.total_credits_used: the total credits used during the summary
|
||||
data.total_executions: the total number of executions during the summary
|
||||
data.most_used_agent: the most used agent's nameduring the summary
|
||||
data.most_used_agent: the most used agent's name during the summary
|
||||
data.total_execution_time: the total execution time during the summary
|
||||
data.successful_runs: the total number of successful runs during the summary
|
||||
data.failed_runs: the total number of failed runs during the summary
|
||||
data.average_execution_time: the average execution time during the summary
|
||||
data.cost_breakdown: the cost breakdown during the summary
|
||||
data.cost_breakdown: the cost breakdown during the summary (dict mapping agent names to credit amounts)
|
||||
#}
|
||||
|
||||
<h1>Weekly Summary</h1>
|
||||
<h1 style="color: #5D23BB; font-size: 32px; font-weight: 600; margin-bottom: 25px; margin-top: 0;">
|
||||
Weekly Summary
|
||||
</h1>
|
||||
|
||||
<p>Start Date: {{ data.start_date }}</p>
|
||||
<p>End Date: {{ data.end_date }}</p>
|
||||
<p>Total Credits Used: {{ data.total_credits_used }}</p>
|
||||
<p>Total Executions: {{ data.total_executions }}</p>
|
||||
<p>Most Used Agent: {{ data.most_used_agent }}</p>
|
||||
<p>Total Execution Time: {{ data.total_execution_time }}</p>
|
||||
<p>Successful Runs: {{ data.successful_runs }}</p>
|
||||
<p>Failed Runs: {{ data.failed_runs }}</p>
|
||||
<p>Average Execution Time: {{ data.average_execution_time }}</p>
|
||||
<p>Cost Breakdown: {{ data.cost_breakdown }}</p>
|
||||
<h2 style="color: #070629; font-size: 24px; font-weight: 500; margin-bottom: 20px;">
|
||||
Your Agent Activity: {{ data.start_date.strftime('%B %-d') }} – {{ data.end_date.strftime('%B %-d') }}
|
||||
</h2>
|
||||
|
||||
<div style="background-color: #ffffff; border-radius: 8px; padding: 20px; margin-bottom: 25px;">
|
||||
<ul style="list-style-type: disc; padding-left: 20px; margin: 0;">
|
||||
<li style="font-size: 16px; line-height: 1.8; margin-bottom: 8px;">
|
||||
<strong>Total Executions:</strong> {{ data.total_executions }}
|
||||
</li>
|
||||
<li style="font-size: 16px; line-height: 1.8; margin-bottom: 8px;">
|
||||
<strong>Total Credits Used:</strong> {{ data.total_credits_used|format("%.2f") }}
|
||||
</li>
|
||||
<li style="font-size: 16px; line-height: 1.8; margin-bottom: 8px;">
|
||||
<strong>Total Execution Time:</strong> {{ data.total_execution_time|format("%.1f") }} seconds
|
||||
</li>
|
||||
<li style="font-size: 16px; line-height: 1.8; margin-bottom: 8px;">
|
||||
<strong>Successful Runs:</strong> {{ data.successful_runs }}
|
||||
</li>
|
||||
<li style="font-size: 16px; line-height: 1.8; margin-bottom: 8px;">
|
||||
<strong>Failed Runs:</strong> {{ data.failed_runs }}
|
||||
</li>
|
||||
<li style="font-size: 16px; line-height: 1.8; margin-bottom: 8px;">
|
||||
<strong>Average Execution Time:</strong> {{ data.average_execution_time|format("%.1f") }} seconds
|
||||
</li>
|
||||
<li style="font-size: 16px; line-height: 1.8; margin-bottom: 8px;">
|
||||
<strong>Most Used Agent:</strong> {{ data.most_used_agent }}
|
||||
</li>
|
||||
{% if data.cost_breakdown %}
|
||||
<li style="font-size: 16px; line-height: 1.8; margin-bottom: 8px;">
|
||||
<strong>Cost Breakdown:</strong>
|
||||
<ul style="list-style-type: disc; padding-left: 40px; margin-top: 8px;">
|
||||
{% for agent_name, credits in data.cost_breakdown.items() %}
|
||||
<li style="font-size: 16px; line-height: 1.8; margin-bottom: 4px;">
|
||||
{{ agent_name }}: {{ credits|format("%.2f") }} credits
|
||||
</li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
</li>
|
||||
{% endif %}
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<p style="font-size: 16px; line-height: 165%; margin-top: 20px; margin-bottom: 10px;">
|
||||
Thank you for being a part of the AutoGPT community! 🎉
|
||||
</p>
|
||||
|
||||
<p style="font-size: 16px; line-height: 165%; margin-bottom: 0;">
|
||||
Join the conversation on <a href="https://discord.gg/autogpt" style="color: #4285F4; text-decoration: underline;">Discord here</a>.
|
||||
</p>
|
||||
@@ -1,6 +1,5 @@
|
||||
from backend.app import run_processes
|
||||
from backend.executor.scheduler import Scheduler
|
||||
from backend.notifications.notifications import NotificationManager
|
||||
|
||||
|
||||
def main():
|
||||
@@ -8,7 +7,6 @@ def main():
|
||||
Run all the processes required for the AutoGPT-server Scheduling System.
|
||||
"""
|
||||
run_processes(
|
||||
NotificationManager(),
|
||||
Scheduler(),
|
||||
)
|
||||
|
||||
|
||||
@@ -29,7 +29,9 @@ from backend.data.model import (
|
||||
CredentialsType,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
UserIntegrations,
|
||||
)
|
||||
from backend.data.user import get_user_integrations
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
@@ -585,7 +587,10 @@ async def get_ayrshare_sso_url(
|
||||
# Ayrshare profile key is stored in the credentials store
|
||||
# It is generated when creating a new profile, if there is no profile key,
|
||||
# we create a new profile and store the profile key in the credentials store
|
||||
profile_key = await creds_manager.store.get_ayrshare_profile_key(user_id)
|
||||
|
||||
user_integrations: UserIntegrations = await get_user_integrations(user_id)
|
||||
profile_key = user_integrations.managed_credentials.ayrshare_profile_key
|
||||
|
||||
if not profile_key:
|
||||
logger.debug(f"Creating new Ayrshare profile for user {user_id}")
|
||||
try:
|
||||
@@ -629,7 +634,7 @@ async def get_ayrshare_sso_url(
|
||||
# SocialPlatform.TELEGRAM,
|
||||
# SocialPlatform.GOOGLE_MY_BUSINESS,
|
||||
# SocialPlatform.PINTEREST,
|
||||
# SocialPlatform.TIKTOK,
|
||||
SocialPlatform.TIKTOK,
|
||||
# SocialPlatform.BLUESKY,
|
||||
# SocialPlatform.SNAPCHAT,
|
||||
# SocialPlatform.THREADS,
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
from supabase import Client, create_client
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
def get_supabase() -> Client:
|
||||
return create_client(
|
||||
settings.secrets.supabase_url, settings.secrets.supabase_service_role_key
|
||||
)
|
||||
@@ -9,11 +9,6 @@ import fastapi.responses
|
||||
import pydantic
|
||||
import starlette.middleware.cors
|
||||
import uvicorn
|
||||
from autogpt_libs.feature_flag.client import (
|
||||
initialize_launchdarkly,
|
||||
shutdown_launchdarkly,
|
||||
)
|
||||
from autogpt_libs.logging.utils import generate_uvicorn_config
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
@@ -40,6 +35,9 @@ from backend.integrations.providers import ProviderName
|
||||
from backend.server.external.api import external_app
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
from backend.util import json
|
||||
from backend.util.cloud_storage import shutdown_cloud_storage_handler
|
||||
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
|
||||
from backend.util.service import UnhealthyServiceError
|
||||
|
||||
settings = backend.util.settings.Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -75,6 +73,12 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
|
||||
with launch_darkly_context():
|
||||
yield
|
||||
|
||||
try:
|
||||
await shutdown_cloud_storage_handler()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error shutting down cloud storage handler: {e}")
|
||||
|
||||
await backend.data.db.disconnect()
|
||||
|
||||
|
||||
@@ -224,6 +228,8 @@ app.mount("/external-api", external_app)
|
||||
|
||||
@app.get(path="/health", tags=["health"], dependencies=[])
|
||||
async def health():
|
||||
if not backend.data.db.is_connected():
|
||||
raise UnhealthyServiceError("Database is not connected")
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
@@ -240,7 +246,7 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
server_app,
|
||||
host=backend.util.settings.Config().agent_api_host,
|
||||
port=backend.util.settings.Config().agent_api_port,
|
||||
log_config=generate_uvicorn_config(),
|
||||
log_config=None,
|
||||
)
|
||||
|
||||
def cleanup(self):
|
||||
|
||||
@@ -8,8 +8,6 @@ from typing import Annotated, Any, Sequence
|
||||
import pydantic
|
||||
import stripe
|
||||
from autogpt_libs.auth.middleware import auth_middleware
|
||||
from autogpt_libs.feature_flag.client import feature_flag
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
@@ -51,7 +49,6 @@ from backend.data.credit import (
|
||||
get_user_credit_model,
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.execution import AsyncRedisExecutionEventBus
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.data.onboarding import (
|
||||
@@ -84,18 +81,14 @@ from backend.server.model import (
|
||||
UploadFileResponse,
|
||||
)
|
||||
from backend.server.utils import get_user_id
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.cloud_storage import get_cloud_storage_handler
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.exceptions import GraphValidationError, NotFoundError
|
||||
from backend.util.feature_flag import feature_flag
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_scheduler_client() -> scheduler.SchedulerClient:
|
||||
return get_service_client(scheduler.SchedulerClient, health_check=False)
|
||||
|
||||
|
||||
def _create_file_size_error(size_bytes: int, max_size_mb: int) -> HTTPException:
|
||||
"""Create standardized file size error response."""
|
||||
return HTTPException(
|
||||
@@ -104,11 +97,6 @@ def _create_file_size_error(size_bytes: int, max_size_mb: int) -> HTTPException:
|
||||
)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def execution_event_bus() -> AsyncRedisExecutionEventBus:
|
||||
return AsyncRedisExecutionEventBus()
|
||||
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -470,12 +458,16 @@ async def stripe_webhook(request: Request):
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, settings.secrets.stripe_webhook_secret
|
||||
)
|
||||
except ValueError:
|
||||
except ValueError as e:
|
||||
# Invalid payload
|
||||
raise HTTPException(status_code=400)
|
||||
except stripe.SignatureVerificationError:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid payload: {str(e) or type(e).__name__}"
|
||||
)
|
||||
except stripe.SignatureVerificationError as e:
|
||||
# Invalid signature
|
||||
raise HTTPException(status_code=400)
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid signature: {str(e) or type(e).__name__}"
|
||||
)
|
||||
|
||||
if (
|
||||
event["type"] == "checkout.session.completed"
|
||||
@@ -688,7 +680,15 @@ async def update_graph(
|
||||
# Handle deactivation of the previously active version
|
||||
await on_graph_deactivate(current_active_version, user_id=user_id)
|
||||
|
||||
return new_graph_version
|
||||
# Fetch new graph version *with sub-graphs* (needed for credentials input schema)
|
||||
new_graph_version_with_subgraphs = await graph_db.get_graph(
|
||||
graph_id,
|
||||
new_graph_version.version,
|
||||
user_id=user_id,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
assert new_graph_version_with_subgraphs # make type checker happy
|
||||
return new_graph_version_with_subgraphs
|
||||
|
||||
|
||||
@v1_router.put(
|
||||
@@ -753,15 +753,27 @@ async def execute_graph(
|
||||
detail="Insufficient balance to execute the agent. Please top up your account.",
|
||||
)
|
||||
|
||||
graph_exec = await execution_utils.add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
preset_id=preset_id,
|
||||
graph_version=graph_version,
|
||||
graph_credentials_inputs=credentials_inputs,
|
||||
)
|
||||
return ExecuteGraphResponse(graph_exec_id=graph_exec.id)
|
||||
try:
|
||||
graph_exec = await execution_utils.add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
preset_id=preset_id,
|
||||
graph_version=graph_version,
|
||||
graph_credentials_inputs=credentials_inputs,
|
||||
)
|
||||
return ExecuteGraphResponse(graph_exec_id=graph_exec.id)
|
||||
except GraphValidationError as e:
|
||||
# Return structured validation errors that the frontend can parse
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"type": "validation_error",
|
||||
"message": e.message,
|
||||
# TODO: only return node-specific errors if user has access to graph
|
||||
"node_errors": e.node_errors,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
@@ -912,7 +924,7 @@ async def create_graph_execution_schedule(
|
||||
detail=f"Graph #{graph_id} v{schedule_params.graph_version} not found.",
|
||||
)
|
||||
|
||||
return await execution_scheduler_client().add_execution_schedule(
|
||||
return await get_scheduler_client().add_execution_schedule(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
@@ -933,7 +945,7 @@ async def list_graph_execution_schedules(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
graph_id: str = Path(),
|
||||
) -> list[scheduler.GraphExecutionJobInfo]:
|
||||
return await execution_scheduler_client().get_execution_schedules(
|
||||
return await get_scheduler_client().get_execution_schedules(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
)
|
||||
@@ -948,7 +960,7 @@ async def list_graph_execution_schedules(
|
||||
async def list_all_graphs_execution_schedules(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[scheduler.GraphExecutionJobInfo]:
|
||||
return await execution_scheduler_client().get_execution_schedules(user_id=user_id)
|
||||
return await get_scheduler_client().get_execution_schedules(user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.delete(
|
||||
@@ -962,7 +974,7 @@ async def delete_graph_execution_schedule(
|
||||
schedule_id: str = Path(..., description="ID of the schedule to delete"),
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
await execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
|
||||
await get_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
|
||||
except NotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_404_NOT_FOUND,
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
# AutoMod integration for content moderation
|
||||
353
autogpt_platform/backend/backend/server/v2/AutoMod/manager.py
Normal file
353
autogpt_platform/backend/backend/server/v2/AutoMod/manager.py
Normal file
@@ -0,0 +1,353 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.server.v2.AutoMod.models import (
|
||||
AutoModRequest,
|
||||
AutoModResponse,
|
||||
ModerationConfig,
|
||||
)
|
||||
from backend.util.exceptions import ModerationError
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
from backend.util.request import Requests
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AutoModManager:
|
||||
|
||||
def __init__(self):
|
||||
self.config = self._load_config()
|
||||
|
||||
def _load_config(self) -> ModerationConfig:
|
||||
"""Load AutoMod configuration from settings"""
|
||||
settings = Settings()
|
||||
return ModerationConfig(
|
||||
enabled=settings.config.automod_enabled,
|
||||
api_url=settings.config.automod_api_url,
|
||||
api_key=settings.secrets.automod_api_key,
|
||||
timeout=settings.config.automod_timeout,
|
||||
retry_attempts=settings.config.automod_retry_attempts,
|
||||
retry_delay=settings.config.automod_retry_delay,
|
||||
fail_open=settings.config.automod_fail_open,
|
||||
)
|
||||
|
||||
async def moderate_graph_execution_inputs(
|
||||
self, db_client: "DatabaseManagerAsyncClient", graph_exec, timeout: int = 10
|
||||
) -> Exception | None:
|
||||
"""
|
||||
Complete input moderation flow for graph execution
|
||||
Returns: error_if_failed (None means success)
|
||||
"""
|
||||
if not self.config.enabled:
|
||||
return None
|
||||
|
||||
# Check if AutoMod feature is enabled for this user
|
||||
if not await is_feature_enabled(Flag.AUTOMOD, graph_exec.user_id):
|
||||
logger.debug(f"AutoMod feature not enabled for user {graph_exec.user_id}")
|
||||
return None
|
||||
|
||||
# Get graph model and collect all inputs
|
||||
graph_model = await db_client.get_graph(
|
||||
graph_exec.graph_id,
|
||||
user_id=graph_exec.user_id,
|
||||
version=graph_exec.graph_version,
|
||||
)
|
||||
|
||||
if not graph_model or not graph_model.nodes:
|
||||
return None
|
||||
|
||||
all_inputs = []
|
||||
for node in graph_model.nodes:
|
||||
if node.input_default:
|
||||
all_inputs.extend(str(v) for v in node.input_default.values() if v)
|
||||
if (masks := graph_exec.nodes_input_masks) and (mask := masks.get(node.id)):
|
||||
all_inputs.extend(str(v) for v in mask.values() if v)
|
||||
|
||||
if not all_inputs:
|
||||
return None
|
||||
|
||||
# Combine all content and moderate directly
|
||||
content = " ".join(all_inputs)
|
||||
|
||||
# Run moderation
|
||||
logger.warning(
|
||||
f"Moderating inputs for graph execution {graph_exec.graph_exec_id}"
|
||||
)
|
||||
try:
|
||||
moderation_passed = await self._moderate_content(
|
||||
content,
|
||||
{
|
||||
"user_id": graph_exec.user_id,
|
||||
"graph_id": graph_exec.graph_id,
|
||||
"graph_exec_id": graph_exec.graph_exec_id,
|
||||
"moderation_type": "execution_input",
|
||||
},
|
||||
)
|
||||
|
||||
if not moderation_passed:
|
||||
logger.warning(
|
||||
f"Moderation failed for graph execution {graph_exec.graph_exec_id}"
|
||||
)
|
||||
# Update node statuses for frontend display before raising error
|
||||
await self._update_failed_nodes_for_moderation(
|
||||
db_client, graph_exec.graph_exec_id, "input"
|
||||
)
|
||||
|
||||
return ModerationError(
|
||||
message="Execution failed due to input content moderation",
|
||||
user_id=graph_exec.user_id,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
moderation_type="input",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"Input moderation timed out for graph execution {graph_exec.graph_exec_id}, bypassing moderation"
|
||||
)
|
||||
return None # Bypass moderation on timeout
|
||||
except Exception as e:
|
||||
logger.warning(f"Input moderation execution failed: {e}")
|
||||
return ModerationError(
|
||||
message="Execution failed due to input content moderation error",
|
||||
user_id=graph_exec.user_id,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
moderation_type="input",
|
||||
)
|
||||
|
||||
async def moderate_graph_execution_outputs(
|
||||
self,
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
timeout: int = 10,
|
||||
) -> Exception | None:
|
||||
"""
|
||||
Complete output moderation flow for graph execution
|
||||
Returns: error_if_failed (None means success)
|
||||
"""
|
||||
if not self.config.enabled:
|
||||
return None
|
||||
|
||||
# Check if AutoMod feature is enabled for this user
|
||||
if not await is_feature_enabled(Flag.AUTOMOD, user_id):
|
||||
logger.debug(f"AutoMod feature not enabled for user {user_id}")
|
||||
return None
|
||||
|
||||
# Get completed executions and collect outputs
|
||||
completed_executions = await db_client.get_node_executions(
|
||||
graph_exec_id, statuses=[ExecutionStatus.COMPLETED], include_exec_data=True
|
||||
)
|
||||
|
||||
if not completed_executions:
|
||||
return None
|
||||
|
||||
all_outputs = []
|
||||
for exec_entry in completed_executions:
|
||||
if exec_entry.output_data:
|
||||
all_outputs.extend(str(v) for v in exec_entry.output_data.values() if v)
|
||||
|
||||
if not all_outputs:
|
||||
return None
|
||||
|
||||
# Combine all content and moderate directly
|
||||
content = " ".join(all_outputs)
|
||||
|
||||
# Run moderation
|
||||
logger.warning(f"Moderating outputs for graph execution {graph_exec_id}")
|
||||
try:
|
||||
moderation_passed = await self._moderate_content(
|
||||
content,
|
||||
{
|
||||
"user_id": user_id,
|
||||
"graph_id": graph_id,
|
||||
"graph_exec_id": graph_exec_id,
|
||||
"moderation_type": "execution_output",
|
||||
},
|
||||
)
|
||||
|
||||
if not moderation_passed:
|
||||
logger.warning(f"Moderation failed for graph execution {graph_exec_id}")
|
||||
# Update node statuses for frontend display before raising error
|
||||
await self._update_failed_nodes_for_moderation(
|
||||
db_client, graph_exec_id, "output"
|
||||
)
|
||||
|
||||
return ModerationError(
|
||||
message="Execution failed due to output content moderation",
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
moderation_type="output",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"Output moderation timed out for graph execution {graph_exec_id}, bypassing moderation"
|
||||
)
|
||||
return None # Bypass moderation on timeout
|
||||
except Exception as e:
|
||||
logger.warning(f"Output moderation execution failed: {e}")
|
||||
return ModerationError(
|
||||
message="Execution failed due to output content moderation error",
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
moderation_type="output",
|
||||
)
|
||||
|
||||
async def _update_failed_nodes_for_moderation(
|
||||
self,
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
graph_exec_id: str,
|
||||
moderation_type: Literal["input", "output"],
|
||||
):
|
||||
"""Update node execution statuses for frontend display when moderation fails"""
|
||||
# Import here to avoid circular imports
|
||||
from backend.executor.manager import send_async_execution_update
|
||||
|
||||
if moderation_type == "input":
|
||||
# For input moderation, mark queued/running/incomplete nodes as failed
|
||||
target_statuses = [
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.RUNNING,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
]
|
||||
else:
|
||||
# For output moderation, mark completed nodes as failed
|
||||
target_statuses = [ExecutionStatus.COMPLETED]
|
||||
|
||||
# Get the executions that need to be updated
|
||||
executions_to_update = await db_client.get_node_executions(
|
||||
graph_exec_id, statuses=target_statuses, include_exec_data=True
|
||||
)
|
||||
|
||||
if not executions_to_update:
|
||||
return
|
||||
|
||||
# Prepare database update tasks
|
||||
exec_updates = []
|
||||
for exec_entry in executions_to_update:
|
||||
# Collect all input and output names to clear
|
||||
cleared_inputs = {}
|
||||
cleared_outputs = {}
|
||||
|
||||
if exec_entry.input_data:
|
||||
for name in exec_entry.input_data.keys():
|
||||
cleared_inputs[name] = ["Failed due to content moderation"]
|
||||
|
||||
if exec_entry.output_data:
|
||||
for name in exec_entry.output_data.keys():
|
||||
cleared_outputs[name] = ["Failed due to content moderation"]
|
||||
|
||||
# Add update task to list
|
||||
exec_updates.append(
|
||||
db_client.update_node_execution_status(
|
||||
exec_entry.node_exec_id,
|
||||
status=ExecutionStatus.FAILED,
|
||||
stats={
|
||||
"error": "Failed due to content moderation",
|
||||
"cleared_inputs": cleared_inputs,
|
||||
"cleared_outputs": cleared_outputs,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Execute all database updates in parallel
|
||||
updated_execs = await asyncio.gather(*exec_updates)
|
||||
|
||||
# Send all websocket updates in parallel
|
||||
await asyncio.gather(
|
||||
*[
|
||||
send_async_execution_update(updated_exec)
|
||||
for updated_exec in updated_execs
|
||||
]
|
||||
)
|
||||
|
||||
async def _moderate_content(self, content: str, metadata: dict[str, Any]) -> bool:
|
||||
"""Moderate content using AutoMod API
|
||||
|
||||
Returns:
|
||||
True: Content approved or timeout occurred
|
||||
False: Content rejected by moderation
|
||||
|
||||
Raises:
|
||||
asyncio.TimeoutError: When moderation times out (should be bypassed)
|
||||
"""
|
||||
try:
|
||||
request_data = AutoModRequest(
|
||||
type="text",
|
||||
content=content,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
response = await self._make_request(request_data)
|
||||
|
||||
if response.success and response.status == "approved":
|
||||
logger.debug(
|
||||
f"Content approved for {metadata.get('graph_exec_id', 'unknown')}"
|
||||
)
|
||||
return True
|
||||
else:
|
||||
reasons = [r.reason for r in response.moderation_results if r.reason]
|
||||
error_msg = f"Content rejected by AutoMod: {'; '.join(reasons)}"
|
||||
logger.warning(f"Content rejected: {error_msg}")
|
||||
return False
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# Re-raise timeout to be handled by calling methods
|
||||
logger.warning(
|
||||
f"AutoMod API timeout for {metadata.get('graph_exec_id', 'unknown')}"
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"AutoMod moderation error: {e}")
|
||||
return self.config.fail_open
|
||||
|
||||
async def _make_request(self, request_data: AutoModRequest) -> AutoModResponse:
|
||||
"""Make HTTP request to AutoMod API using the standard request utility"""
|
||||
url = f"{self.config.api_url}/moderate"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-API-Key": self.config.api_key.strip(),
|
||||
}
|
||||
|
||||
# Create requests instance with timeout and retry configuration
|
||||
requests = Requests(
|
||||
extra_headers=headers,
|
||||
retry_max_wait=float(self.config.timeout),
|
||||
)
|
||||
|
||||
try:
|
||||
response = await requests.post(
|
||||
url, json=request_data.model_dump(), timeout=self.config.timeout
|
||||
)
|
||||
|
||||
response_data = response.json()
|
||||
return AutoModResponse.model_validate(response_data)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# Re-raise timeout error to be caught by _moderate_content
|
||||
raise
|
||||
except (json.JSONDecodeError, ValidationError) as e:
|
||||
raise Exception(f"Invalid response from AutoMod API: {e}")
|
||||
except Exception as e:
|
||||
# Check if this is an aiohttp timeout that we should convert
|
||||
if "timeout" in str(e).lower():
|
||||
raise asyncio.TimeoutError(f"AutoMod API request timed out: {e}")
|
||||
raise Exception(f"AutoMod API request failed: {e}")
|
||||
|
||||
|
||||
# Global instance
|
||||
automod_manager = AutoModManager()
|
||||
57
autogpt_platform/backend/backend/server/v2/AutoMod/models.py
Normal file
57
autogpt_platform/backend/backend/server/v2/AutoMod/models.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AutoModRequest(BaseModel):
|
||||
"""Request model for AutoMod API"""
|
||||
|
||||
type: str = Field(..., description="Content type - 'text', 'image', 'video'")
|
||||
content: str = Field(..., description="The content to moderate")
|
||||
metadata: Optional[Dict[str, Any]] = Field(
|
||||
default=None, description="Additional context about the content"
|
||||
)
|
||||
|
||||
|
||||
class ModerationResult(BaseModel):
|
||||
"""Individual moderation result"""
|
||||
|
||||
decision: str = Field(
|
||||
..., description="Moderation decision: 'approved', 'rejected', 'flagged'"
|
||||
)
|
||||
reason: Optional[str] = Field(default=None, description="Reason for the decision")
|
||||
|
||||
|
||||
class AutoModResponse(BaseModel):
|
||||
"""Response model for AutoMod API"""
|
||||
|
||||
success: bool = Field(..., description="Whether the request was successful")
|
||||
status: str = Field(
|
||||
..., description="Overall status: 'approved', 'rejected', 'flagged', 'pending'"
|
||||
)
|
||||
moderation_results: List[ModerationResult] = Field(
|
||||
default_factory=list, description="List of moderation results"
|
||||
)
|
||||
|
||||
|
||||
class ModerationConfig(BaseModel):
|
||||
"""Configuration for AutoMod integration"""
|
||||
|
||||
enabled: bool = Field(default=True, description="Whether moderation is enabled")
|
||||
api_url: str = Field(default="", description="AutoMod API base URL")
|
||||
api_key: str = Field(..., description="AutoMod API key")
|
||||
timeout: int = Field(default=30, description="Request timeout in seconds")
|
||||
retry_attempts: int = Field(default=3, description="Number of retry attempts")
|
||||
retry_delay: float = Field(
|
||||
default=1.0, description="Delay between retries in seconds"
|
||||
)
|
||||
fail_open: bool = Field(
|
||||
default=False,
|
||||
description="If True, allow execution to continue if moderation fails",
|
||||
)
|
||||
moderate_inputs: bool = Field(
|
||||
default=True, description="Whether to moderate block inputs"
|
||||
)
|
||||
moderate_outputs: bool = Field(
|
||||
default=True, description="Whether to moderate block outputs"
|
||||
)
|
||||
@@ -241,7 +241,11 @@ async def get_library_agent_by_graph_id(
|
||||
)
|
||||
if not agent:
|
||||
return None
|
||||
return library_model.LibraryAgent.from_db(agent)
|
||||
|
||||
assert agent.AgentGraph # make type checker happy
|
||||
# Include sub-graphs so we can make a full credentials input schema
|
||||
sub_graphs = await graph_db.get_sub_graphs(agent.AgentGraph)
|
||||
return library_model.LibraryAgent.from_db(agent, sub_graphs=sub_graphs)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error fetching library agent by graph ID: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to fetch library agent") from e
|
||||
|
||||
@@ -33,30 +33,30 @@ async def check_media_exists(user_id: str, filename: str) -> str | None:
|
||||
if not settings.config.media_gcs_bucket_name:
|
||||
raise MissingConfigError("GCS media bucket is not configured")
|
||||
|
||||
async_client = async_storage.Storage()
|
||||
bucket_name = settings.config.media_gcs_bucket_name
|
||||
async with async_storage.Storage() as async_client:
|
||||
bucket_name = settings.config.media_gcs_bucket_name
|
||||
|
||||
# Check images
|
||||
image_path = f"users/{user_id}/images/{filename}"
|
||||
try:
|
||||
await async_client.download_metadata(bucket_name, image_path)
|
||||
# If we get here, the file exists - construct public URL
|
||||
return f"https://storage.googleapis.com/{bucket_name}/{image_path}"
|
||||
except Exception:
|
||||
# File doesn't exist, continue to check videos
|
||||
pass
|
||||
# Check images
|
||||
image_path = f"users/{user_id}/images/{filename}"
|
||||
try:
|
||||
await async_client.download_metadata(bucket_name, image_path)
|
||||
# If we get here, the file exists - construct public URL
|
||||
return f"https://storage.googleapis.com/{bucket_name}/{image_path}"
|
||||
except Exception:
|
||||
# File doesn't exist, continue to check videos
|
||||
pass
|
||||
|
||||
# Check videos
|
||||
video_path = f"users/{user_id}/videos/{filename}"
|
||||
try:
|
||||
await async_client.download_metadata(bucket_name, video_path)
|
||||
# If we get here, the file exists - construct public URL
|
||||
return f"https://storage.googleapis.com/{bucket_name}/{video_path}"
|
||||
except Exception:
|
||||
# File doesn't exist
|
||||
pass
|
||||
# Check videos
|
||||
video_path = f"users/{user_id}/videos/{filename}"
|
||||
try:
|
||||
await async_client.download_metadata(bucket_name, video_path)
|
||||
# If we get here, the file exists - construct public URL
|
||||
return f"https://storage.googleapis.com/{bucket_name}/{video_path}"
|
||||
except Exception:
|
||||
# File doesn't exist
|
||||
pass
|
||||
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
async def upload_media(
|
||||
@@ -177,22 +177,24 @@ async def upload_media(
|
||||
storage_path = f"users/{user_id}/{media_type}/{unique_filename}"
|
||||
|
||||
try:
|
||||
async_client = async_storage.Storage()
|
||||
bucket_name = settings.config.media_gcs_bucket_name
|
||||
async with async_storage.Storage() as async_client:
|
||||
bucket_name = settings.config.media_gcs_bucket_name
|
||||
|
||||
file_bytes = await file.read()
|
||||
await scan_content_safe(file_bytes, filename=unique_filename)
|
||||
file_bytes = await file.read()
|
||||
await scan_content_safe(file_bytes, filename=unique_filename)
|
||||
|
||||
# Upload using pure async client
|
||||
await async_client.upload(
|
||||
bucket_name, storage_path, file_bytes, content_type=content_type
|
||||
)
|
||||
# Upload using pure async client
|
||||
await async_client.upload(
|
||||
bucket_name, storage_path, file_bytes, content_type=content_type
|
||||
)
|
||||
|
||||
# Construct public URL
|
||||
public_url = f"https://storage.googleapis.com/{bucket_name}/{storage_path}"
|
||||
# Construct public URL
|
||||
public_url = (
|
||||
f"https://storage.googleapis.com/{bucket_name}/{storage_path}"
|
||||
)
|
||||
|
||||
logger.info(f"Successfully uploaded file to: {storage_path}")
|
||||
return public_url
|
||||
logger.info(f"Successfully uploaded file to: {storage_path}")
|
||||
return public_url
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"GCS storage error: {str(e)}")
|
||||
|
||||
@@ -26,6 +26,10 @@ def mock_storage_client(mocker):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.upload = AsyncMock()
|
||||
|
||||
# Mock context manager methods
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Mock the constructor to return our mock client
|
||||
mocker.patch(
|
||||
"backend.server.v2.store.media.async_storage.Storage", return_value=mock_client
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Protocol
|
||||
import pydantic
|
||||
import uvicorn
|
||||
from autogpt_libs.auth import parse_jwt_token
|
||||
from autogpt_libs.logging.utils import generate_uvicorn_config
|
||||
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
@@ -309,7 +308,7 @@ class WebsocketServer(AppProcess):
|
||||
server_app,
|
||||
host=Config().websocket_server_host,
|
||||
port=Config().websocket_server_port,
|
||||
log_config=generate_uvicorn_config(),
|
||||
log_config=None,
|
||||
)
|
||||
|
||||
def cleanup(self):
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
from pathlib import Path
|
||||
|
||||
from prisma.models import User
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.block import BlockInstallationBlock
|
||||
from backend.blocks.http import SendWebRequestBlock
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.blocks.text import ExtractTextInformationBlock, FillTextTemplateBlock
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
from backend.data.model import User
|
||||
from backend.data.user import get_or_create_user
|
||||
from backend.util.test import SpinTestServer, wait_execution
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from prisma.models import User
|
||||
|
||||
from backend.blocks.llm import AIStructuredResponseGeneratorBlock
|
||||
from backend.blocks.reddit import GetRedditPostsBlock, PostRedditCommentBlock
|
||||
from backend.blocks.text import FillTextTemplateBlock, MatchTextPatternBlock
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
from backend.data.model import User
|
||||
from backend.data.user import get_or_create_user
|
||||
from backend.util.test import SpinTestServer, wait_execution
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user