Compare commits

..

4 Commits

Author SHA1 Message Date
Swifty
894e3600fb add other specs 2025-08-01 14:21:57 +02:00
Swifty
9de4b09f20 mv to sub dir 2025-08-01 13:19:42 +02:00
Swifty
62e41d409a websocket server running well now 2025-08-01 13:17:45 +02:00
Swifty
9f03e3af47 added websocket service 2025-08-01 11:19:29 +02:00
678 changed files with 36885 additions and 21342 deletions

View File

@@ -15,7 +15,6 @@
!autogpt_platform/backend/pyproject.toml
!autogpt_platform/backend/poetry.lock
!autogpt_platform/backend/README.md
!autogpt_platform/backend/.env
# Platform - Market
!autogpt_platform/market/market/
@@ -28,7 +27,6 @@
# Platform - Frontend
!autogpt_platform/frontend/src/
!autogpt_platform/frontend/public/
!autogpt_platform/frontend/scripts/
!autogpt_platform/frontend/package.json
!autogpt_platform/frontend/pnpm-lock.yaml
!autogpt_platform/frontend/tsconfig.json
@@ -36,7 +34,6 @@
## config
!autogpt_platform/frontend/*.config.*
!autogpt_platform/frontend/.env.*
!autogpt_platform/frontend/.env
# Classic - AutoGPT
!classic/original_autogpt/autogpt/

View File

@@ -24,8 +24,7 @@
</details>
#### For configuration changes:
- [ ] `.env.default` is updated or already compatible with my changes
- [ ] `.env.example` is updated or already compatible with my changes
- [ ] `docker-compose.yml` is updated or already compatible with my changes
- [ ] I have included a list of my configuration changes in the PR description (under **Changes**)

View File

@@ -82,6 +82,37 @@ jobs:
- name: Run lint
run: pnpm lint
type-check:
runs-on: ubuntu-latest
needs: setup
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "21"
- name: Enable corepack
run: corepack enable
- name: Restore dependencies cache
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Run tsc check
run: pnpm type-check
chromatic:
runs-on: ubuntu-latest
needs: setup
@@ -145,7 +176,11 @@ jobs:
- name: Copy default supabase .env
run: |
cp ../.env.default ../.env
cp ../.env.example ../.env
- name: Copy backend .env
run: |
cp ../backend/.env.example ../backend/.env
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -217,6 +252,15 @@ jobs:
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Setup .env
run: cp .env.example .env
- name: Build frontend
run: pnpm build --turbo
# uses Turbopack, much faster and safe enough for a test pipeline
env:
NEXT_PUBLIC_PW_TEST: true
- name: Install Browser 'chromium'
run: pnpm playwright install --with-deps chromium

View File

@@ -1,132 +0,0 @@
name: AutoGPT Platform - Frontend CI
on:
push:
branches: [master, dev]
paths:
- ".github/workflows/platform-fullstack-ci.yml"
- "autogpt_platform/**"
pull_request:
paths:
- ".github/workflows/platform-fullstack-ci.yml"
- "autogpt_platform/**"
merge_group:
defaults:
run:
shell: bash
working-directory: autogpt_platform/frontend
jobs:
setup:
runs-on: ubuntu-latest
outputs:
cache-key: ${{ steps.cache-key.outputs.key }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "21"
- name: Enable corepack
run: corepack enable
- name: Generate cache key
id: cache-key
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
- name: Cache dependencies
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
types:
runs-on: ubuntu-latest
needs: setup
strategy:
fail-fast: false
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "21"
- name: Enable corepack
run: corepack enable
- name: Copy default supabase .env
run: |
cp ../.env.default ../.env
- name: Copy backend .env
run: |
cp ../backend/.env.default ../backend/.env
- name: Run docker compose
run: |
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
- name: Restore dependencies cache
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Setup .env
run: cp .env.default .env
- name: Wait for services to be ready
run: |
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
- name: Generate API queries
run: pnpm generate:api:force
- name: Check for API schema changes
run: |
if ! git diff --exit-code src/app/api/openapi.json; then
echo "❌ API schema changes detected in src/app/api/openapi.json"
echo ""
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
echo "The API schema is now out of sync with the Front-end queries."
echo ""
echo "To fix this:"
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
echo "2. Run 'pnpm generate:api' locally"
echo "3. Run 'pnpm types' locally"
echo "4. Fix any TypeScript errors that may have been introduced"
echo "5. Commit and push your changes"
echo ""
exit 1
else
echo "✅ No API schema changes detected"
fi
- name: Run Typescript checks
run: pnpm types

3
.gitignore vendored
View File

@@ -5,8 +5,6 @@ classic/original_autogpt/*.json
auto_gpt_workspace/*
*.mpeg
.env
# Root .env files
/.env
azure.yaml
.vscode
.idea/*
@@ -123,6 +121,7 @@ celerybeat.pid
# Environments
.direnv/
.env
.venv
env/
venv*/

View File

@@ -235,7 +235,7 @@ repos:
hooks:
- id: tsc
name: Typecheck - AutoGPT Platform - Frontend
entry: bash -c 'cd autogpt_platform/frontend && pnpm types'
entry: bash -c 'cd autogpt_platform/frontend && pnpm type-check'
files: ^autogpt_platform/frontend/
types: [file]
language: system

View File

@@ -3,16 +3,6 @@
[![Discord Follow](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fautogpt%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&label=total%20members&logo=discord&logoColor=white&color=7289da)](https://discord.gg/autogpt) &ensp;
[![Twitter Follow](https://img.shields.io/twitter/follow/Auto_GPT?style=social)](https://twitter.com/Auto_GPT) &ensp;
<!-- Keep these links. Translations will automatically update with the README. -->
[Deutsch](https://zdoc.app/de/Significant-Gravitas/AutoGPT) |
[Español](https://zdoc.app/es/Significant-Gravitas/AutoGPT) |
[français](https://zdoc.app/fr/Significant-Gravitas/AutoGPT) |
[日本語](https://zdoc.app/ja/Significant-Gravitas/AutoGPT) |
[한국어](https://zdoc.app/ko/Significant-Gravitas/AutoGPT) |
[Português](https://zdoc.app/pt/Significant-Gravitas/AutoGPT) |
[Русский](https://zdoc.app/ru/Significant-Gravitas/AutoGPT) |
[中文](https://zdoc.app/zh/Significant-Gravitas/AutoGPT)
**AutoGPT** is a powerful platform that allows you to create, deploy, and manage continuous AI agents that automate complex workflows.
## Hosting Options

View File

@@ -1,11 +1,9 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Repository Overview
AutoGPT Platform is a monorepo containing:
- **Backend** (`/backend`): Python FastAPI server with async support
- **Frontend** (`/frontend`): Next.js React application
- **Shared Libraries** (`/autogpt_libs`): Common Python utilities
@@ -13,7 +11,6 @@ AutoGPT Platform is a monorepo containing:
## Essential Commands
### Backend Development
```bash
# Install dependencies
cd backend && poetry install
@@ -33,18 +30,11 @@ poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name
# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
# Lint and format
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
More details can be found in TESTING.md
#### Creating/Updating Snapshots
@@ -57,8 +47,8 @@ poetry run pytest path/to/test.py --snapshot-update
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
### Frontend Development
### Frontend Development
```bash
# Install dependencies
cd frontend && npm install
@@ -76,13 +66,12 @@ npm run storybook
npm run build
# Type checking
npm run types
npm run type-check
```
## Architecture Overview
### Backend Architecture
- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
@@ -91,7 +80,6 @@ npm run types
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
### Frontend Architecture
- **Framework**: Next.js App Router with React Server Components
- **State Management**: React hooks + Supabase client for real-time updates
- **Workflow Builder**: Visual graph editor using @xyflow/react
@@ -99,7 +87,6 @@ npm run types
- **Feature Flags**: LaunchDarkly integration
### Key Concepts
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
@@ -107,16 +94,13 @@ npm run types
5. **Virus Scanning**: ClamAV integration for file upload security
### Testing Approach
- Backend uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Frontend uses Playwright for E2E tests
- Component testing via Storybook
### Database Schema
Key models (defined in `/backend/schema.prisma`):
- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
@@ -124,31 +108,13 @@ Key models (defined in `/backend/schema.prisma`):
- `StoreListing`: Marketplace listings for sharing agents
### Environment Configuration
#### Configuration Files
- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides)
- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides)
- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides)
#### Docker Environment Loading Order
1. `.env.default` files provide base configuration (tracked in git)
2. `.env` files provide user-specific overrides (gitignored)
3. Docker Compose `environment:` sections provide service-specific overrides
4. Shell environment variables have highest precedence
#### Key Points
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
- Backend: `.env` file in `/backend`
- Frontend: `.env.local` file in `/frontend`
- Both require Supabase credentials and API keys for various services
### Common Development Tasks
**Adding a new block:**
1. Create new file in `/backend/backend/blocks/`
2. Inherit from `Block` base class
3. Define input/output schemas
@@ -156,18 +122,13 @@ Key models (defined in `/backend/schema.prisma`):
5. Register in block registry
6. Generate the block uuid using `uuid.uuid4()`
Note: when making many new blocks, analyze the interfaces of each of these blocks and consider whether they would work well together in a graph-based editor or would struggle to connect productively.
ex: do the inputs and outputs tie well together?
**Modifying the API:**
1. Update route in `/backend/backend/server/routers/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
**Frontend feature development:**
1. Components go in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for new components
@@ -176,7 +137,6 @@ ex: do the inputs and outputs tie well together?
### Security Implementation
**Cache Protection Middleware:**
- Located in `/backend/backend/server/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
@@ -184,47 +144,3 @@ ex: do the inputs and outputs tie well together?
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications
### Creating Pull Requests
- Create the PR against the `dev` branch of the repository.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`).
- Use conventional commit messages (see below).
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description.
- Run the github pre-commit hooks to ensure code quality.
### Reviewing/Revising Pull Requests
- When the user runs /pr-comments or tries to fetch them, also run gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews to get the reviews
- Use gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews/[review_id]/comments to get the review contents
- Use gh api /repos/Significant-Gravitas/AutoGPT/issues/[issuenum]/comments to get the PR-specific comments
### Conventional Commits
Use this format for commit messages and Pull Request titles:
**Conventional Commit Types:**
- `feat`: Introduces a new feature to the codebase
- `fix`: Patches a bug in the codebase
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
- `ci`: Changes to CI configuration
- `docs`: Documentation-only changes
- `dx`: Improvements to the developer experience
**Recommended Base Scopes:**
- `platform`: Changes affecting both frontend and backend
- `frontend`
- `backend`
- `infra`
- `blocks`: Modifications/additions of individual blocks
**Subscope Examples:**
- `backend/executor`
- `backend/db`
- `frontend/builder` (includes changes to the block UI component)
- `infra/prod`
Use these scopes and subscopes for clarity and consistency in commit messages.

View File

@@ -8,6 +8,7 @@ Welcome to the AutoGPT Platform - a powerful system for creating and running AI
- Docker
- Docker Compose V2 (comes with Docker Desktop, or can be installed separately)
- Node.js & NPM (for running the frontend application)
### Running the System
@@ -23,10 +24,10 @@ To run the AutoGPT Platform, follow these steps:
2. Run the following command:
```
cp .env.default .env
cp .env.example .env
```
This command will copy the `.env.default` file to `.env`. You can modify the `.env` file to add your own environment variables.
This command will copy the `.env.example` file to `.env`. You can modify the `.env` file to add your own environment variables.
3. Run the following command:
@@ -36,7 +37,44 @@ To run the AutoGPT Platform, follow these steps:
This command will start all the necessary backend services defined in the `docker-compose.yml` file in detached mode.
4. After all the services are in ready state, open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
4. Navigate to `frontend` within the `autogpt_platform` directory:
```
cd frontend
```
You will need to run your frontend application separately on your local machine.
5. Run the following command:
```
cp .env.example .env.local
```
This command will copy the `.env.example` file to `.env.local` in the `frontend` directory. You can modify the `.env.local` within this folder to add your own environment variables for the frontend application.
6. Enable corepack and install dependencies by running:
```
corepack enable
pnpm i
```
Generate the API client (this step is required before running the frontend):
```
pnpm generate:api-client
```
Then start the frontend application in development mode:
```
pnpm dev
```
7. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
### Docker Compose Commands
@@ -139,21 +177,20 @@ The platform includes scripts for generating and managing the API client:
- `pnpm fetch:openapi`: Fetches the OpenAPI specification from the backend service (requires backend to be running on port 8006)
- `pnpm generate:api-client`: Generates the TypeScript API client from the OpenAPI specification using Orval
- `pnpm generate:api`: Runs both fetch and generate commands in sequence
- `pnpm generate:api-all`: Runs both fetch and generate commands in sequence
#### Manual API Client Updates
If you need to update the API client after making changes to the backend API:
1. Ensure the backend services are running:
```
docker compose up -d
```
2. Generate the updated API client:
```
pnpm generate:api
pnpm generate:api-all
```
This will fetch the latest OpenAPI specification and regenerate the TypeScript client code.

View File

@@ -0,0 +1,802 @@
# DatabaseManager Technical Specification
## Executive Summary
This document provides a complete technical specification for implementing a drop-in replacement for the AutoGPT Platform's DatabaseManager service. The replacement must maintain 100% API compatibility while preserving all functional behaviors, security requirements, and performance characteristics.
## 1. System Overview
### 1.1 Purpose
The DatabaseManager is a centralized service that provides database access for the AutoGPT Platform's executor system. It encapsulates all database operations behind a service interface, enabling distributed execution while maintaining data consistency and security.
### 1.2 Architecture Pattern
- **Service Type**: HTTP-based microservice using FastAPI
- **Communication**: RPC-style over HTTP with JSON serialization
- **Base Class**: Inherits from `AppService` (backend.util.service)
- **Client Classes**: `DatabaseManagerClient` (sync) and `DatabaseManagerAsyncClient` (async)
- **Port**: Configurable via `config.database_api_port`
### 1.3 Critical Requirements
1. **API Compatibility**: All 40+ exposed methods must maintain exact signatures
2. **Type Safety**: Full type preservation across service boundaries
3. **User Isolation**: All operations must respect user_id boundaries
4. **Transaction Support**: Maintain ACID properties for critical operations
5. **Event Publishing**: Maintain Redis event bus integration for real-time updates
## 2. Service Implementation Requirements
### 2.1 Base Service Class
```python
from backend.util.service import AppService, expose
from backend.util.settings import Config
from backend.data import db
import logging
class DatabaseManager(AppService):
"""
REQUIRED: Inherit from AppService to get:
- Automatic endpoint generation via @expose decorator
- Built-in health checks at /health
- Request/response serialization
- Error handling and logging
"""
def run_service(self) -> None:
"""REQUIRED: Initialize database connection before starting service"""
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
self.run_and_wait(db.connect()) # CRITICAL: Must connect to database
super().run_service() # Start HTTP server
def cleanup(self):
"""REQUIRED: Clean disconnect on shutdown"""
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
self.run_and_wait(db.disconnect()) # CRITICAL: Must disconnect cleanly
@classmethod
def get_port(cls) -> int:
"""REQUIRED: Return configured port"""
return config.database_api_port
```
### 2.2 Method Exposure Pattern
```python
@staticmethod
def _(f: Callable[P, R], name: str | None = None) -> Callable[Concatenate[object, P], R]:
"""
REQUIRED: Helper to expose methods with proper signatures
- Preserves function name for endpoint generation
- Maintains type information
- Adds 'self' parameter for instance binding
"""
if name is not None:
f.__name__ = name
return cast(Callable[Concatenate[object, P], R], expose(f))
```
### 2.3 Database Connection Management
**REQUIRED: Use Prisma ORM with these exact configurations:**
```python
from prisma import Prisma
prisma = Prisma(
auto_register=True,
http={"timeout": HTTP_TIMEOUT}, # Default: 120 seconds
datasource={"url": DATABASE_URL}
)
# Connection lifecycle
async def connect():
await prisma.connect()
async def disconnect():
await prisma.disconnect()
```
### 2.4 Transaction Support
**REQUIRED: Implement both regular and locked transactions:**
```python
async def transaction(timeout: float | None = None):
"""Regular database transaction"""
async with prisma.tx(timeout=timeout) as tx:
yield tx
async def locked_transaction(key: str, timeout: float | None = None):
"""Transaction with PostgreSQL advisory lock"""
lock_key = zlib.crc32(key.encode("utf-8"))
async with transaction(timeout=timeout) as tx:
await tx.execute_raw("SELECT pg_advisory_xact_lock($1)", lock_key)
yield tx
```
## 3. Complete API Specification
### 3.1 Execution Management APIs
#### get_graph_execution
```python
async def get_graph_execution(
user_id: str,
execution_id: str,
*,
include_node_executions: bool = False
) -> GraphExecution | GraphExecutionWithNodes | None
```
**Behavior**:
- Returns execution only if user_id matches
- Optionally includes all node executions
- Returns None if not found or unauthorized
#### get_graph_executions
```python
async def get_graph_executions(
user_id: str,
graph_id: str | None = None,
*,
limit: int = 50,
graph_version: int | None = None,
cursor: str | None = None,
preset_id: str | None = None
) -> tuple[list[GraphExecution], str | None]
```
**Behavior**:
- Paginated results with cursor
- Filter by graph_id, version, or preset_id
- Returns (executions, next_cursor)
#### create_graph_execution
```python
async def create_graph_execution(
graph_id: str,
graph_version: int,
starting_nodes_input: dict[str, dict[str, Any]],
user_id: str,
preset_id: str | None = None
) -> GraphExecutionWithNodes
```
**Behavior**:
- Creates execution with status "QUEUED"
- Initializes all nodes with "PENDING" status
- Publishes creation event to Redis
- Uses locked transaction on graph_id
#### update_graph_execution_start_time
```python
async def update_graph_execution_start_time(
graph_exec_id: str
) -> None
```
**Behavior**:
- Sets start_time to current timestamp
- Only updates if currently NULL
#### update_graph_execution_stats
```python
async def update_graph_execution_stats(
graph_exec_id: str,
status: AgentExecutionStatus | None = None,
stats: dict[str, Any] | None = None
) -> GraphExecution | None
```
**Behavior**:
- Updates status and/or stats atomically
- Sets end_time if status is terminal (COMPLETED/FAILED)
- Publishes update event to Redis
- Returns updated execution
#### get_node_execution
```python
async def get_node_execution(
node_exec_id: str
) -> NodeExecutionResult | None
```
**Behavior**:
- No user_id check (relies on graph execution security)
- Includes all input/output data
#### get_node_executions
```python
async def get_node_executions(
graph_exec_id: str
) -> list[NodeExecutionResult]
```
**Behavior**:
- Returns all node executions for graph
- Ordered by creation time
#### get_latest_node_execution
```python
async def get_latest_node_execution(
graph_exec_id: str,
node_id: str
) -> NodeExecutionResult | None
```
**Behavior**:
- Returns most recent execution of specific node
- Used for retry/rerun scenarios
#### update_node_execution_status
```python
async def update_node_execution_status(
node_exec_id: str,
status: AgentExecutionStatus,
execution_data: dict[str, Any] | None = None,
stats: dict[str, Any] | None = None
) -> NodeExecutionResult
```
**Behavior**:
- Updates status atomically
- Sets end_time for terminal states
- Optionally updates stats/data
- Publishes event to Redis
- Returns updated execution
#### update_node_execution_status_batch
```python
async def update_node_execution_status_batch(
execution_updates: list[NodeExecutionUpdate]
) -> list[NodeExecutionResult]
```
**Behavior**:
- Batch update multiple nodes in single transaction
- Each update can have different status/stats
- Publishes events for all updates
- Returns all updated executions
#### update_node_execution_stats
```python
async def update_node_execution_stats(
node_exec_id: str,
stats: dict[str, Any]
) -> NodeExecutionResult
```
**Behavior**:
- Updates only stats field
- Merges with existing stats
- Does not affect status
#### upsert_execution_input
```python
async def upsert_execution_input(
node_id: str,
graph_exec_id: str,
input_name: str,
input_data: Any,
node_exec_id: str | None = None
) -> tuple[str, BlockInput]
```
**Behavior**:
- Creates or updates input data
- If node_exec_id not provided, creates node execution
- Serializes input_data to JSON
- Returns (node_exec_id, input_object)
#### upsert_execution_output
```python
async def upsert_execution_output(
node_exec_id: str,
output_name: str,
output_data: Any
) -> None
```
**Behavior**:
- Creates or updates output data
- Serializes output_data to JSON
- No return value
#### get_execution_kv_data
```python
async def get_execution_kv_data(
user_id: str,
key: str
) -> Any | None
```
**Behavior**:
- User-scoped key-value storage
- Returns deserialized JSON data
- Returns None if key not found
#### set_execution_kv_data
```python
async def set_execution_kv_data(
user_id: str,
node_exec_id: str,
key: str,
data: Any
) -> Any | None
```
**Behavior**:
- Sets user-scoped key-value data
- Associates with node execution
- Serializes data to JSON
- Returns previous value or None
#### get_block_error_stats
```python
async def get_block_error_stats() -> list[BlockErrorStats]
```
**Behavior**:
- Aggregates error counts by block_id
- Last 7 days of data
- Groups by error type
### 3.2 Graph Management APIs
#### get_node
```python
async def get_node(
node_id: str
) -> AgentNode | None
```
**Behavior**:
- Returns node with block data
- No user_id check (public blocks)
#### get_graph
```python
async def get_graph(
graph_id: str,
version: int | None = None,
user_id: str | None = None,
for_export: bool = False,
include_subgraphs: bool = False
) -> GraphModel | None
```
**Behavior**:
- Returns latest version if version=None
- Checks user_id for private graphs
- for_export=True excludes internal fields
- include_subgraphs=True loads nested graphs
#### get_connected_output_nodes
```python
async def get_connected_output_nodes(
node_id: str,
output_name: str
) -> list[tuple[AgentNode, AgentNodeLink]]
```
**Behavior**:
- Returns downstream nodes connected to output
- Includes link metadata
- Used for execution flow
#### get_graph_metadata
```python
async def get_graph_metadata(
graph_id: str,
user_id: str
) -> GraphMetadata | None
```
**Behavior**:
- Returns graph metadata without full definition
- User must own or have access to graph
### 3.3 Credit System APIs
#### get_credits
```python
async def get_credits(
user_id: str
) -> int
```
**Behavior**:
- Returns current credit balance
- Always non-negative
#### spend_credits
```python
async def spend_credits(
user_id: str,
cost: int,
metadata: UsageTransactionMetadata
) -> int
```
**Behavior**:
- Deducts credits atomically
- Creates transaction record
- Raises InsufficientCredits if the balance is too low
- Returns new balance
- metadata includes: block_id, node_exec_id, context
### 3.4 User Management APIs
#### get_user_metadata
```python
async def get_user_metadata(
user_id: str
) -> UserMetadata
```
**Behavior**:
- Returns user preferences and settings
- Creates default if not exists
#### update_user_metadata
```python
async def update_user_metadata(
user_id: str,
data: UserMetadataDTO
) -> UserMetadata
```
**Behavior**:
- Partial update of metadata
- Validates against schema
- Returns updated metadata
#### get_user_integrations
```python
async def get_user_integrations(
user_id: str
) -> UserIntegrations
```
**Behavior**:
- Returns OAuth credentials
- Decrypts sensitive data
- Creates empty if not exists
#### update_user_integrations
```python
async def update_user_integrations(
user_id: str,
data: UserIntegrations
) -> None
```
**Behavior**:
- Updates integration credentials
- Encrypts sensitive data
- No return value
### 3.5 User Communication APIs
#### get_active_user_ids_in_timerange
```python
async def get_active_user_ids_in_timerange(
start_time: datetime,
end_time: datetime
) -> list[str]
```
**Behavior**:
- Returns users with graph executions in range
- Used for analytics/notifications
#### get_user_email_by_id
```python
async def get_user_email_by_id(
user_id: str
) -> str | None
```
**Behavior**:
- Returns user's email address
- None if user not found
#### get_user_email_verification
```python
async def get_user_email_verification(
user_id: str
) -> UserEmailVerification
```
**Behavior**:
- Returns email and verification status
- Used for notification filtering
#### get_user_notification_preference
```python
async def get_user_notification_preference(
user_id: str
) -> NotificationPreference
```
**Behavior**:
- Returns notification settings
- Creates default if not exists
### 3.6 Notification APIs
#### create_or_add_to_user_notification_batch
```python
async def create_or_add_to_user_notification_batch(
user_id: str,
notification_type: NotificationType,
notification_data: NotificationEvent
) -> UserNotificationBatchDTO
```
**Behavior**:
- Adds to existing batch or creates new
- Batches by type for efficiency
- Returns updated batch
#### empty_user_notification_batch
```python
async def empty_user_notification_batch(
user_id: str,
notification_type: NotificationType
) -> None
```
**Behavior**:
- Clears all notifications of type
- Used after sending batch
#### get_all_batches_by_type
```python
async def get_all_batches_by_type(
notification_type: NotificationType
) -> list[UserNotificationBatchDTO]
```
**Behavior**:
- Returns all user batches of type
- Used by notification service
#### get_user_notification_batch
```python
async def get_user_notification_batch(
user_id: str,
notification_type: NotificationType
) -> UserNotificationBatchDTO | None
```
**Behavior**:
- Returns user's batch for type
- None if no batch exists
#### get_user_notification_oldest_message_in_batch
```python
async def get_user_notification_oldest_message_in_batch(
user_id: str,
notification_type: NotificationType
) -> NotificationEvent | None
```
**Behavior**:
- Returns oldest notification in batch
- Used for batch timing decisions
## 4. Client Implementation Requirements
### 4.1 Synchronous Client
```python
class DatabaseManagerClient(AppServiceClient):
"""
REQUIRED: Synchronous client that:
- Converts async methods to sync using endpoint_to_sync
- Maintains exact method signatures
- Handles connection pooling
- Implements retry logic
"""
@classmethod
def get_service_type(cls):
return DatabaseManager
# Example method mapping
get_graph_execution = endpoint_to_sync(DatabaseManager.get_graph_execution)
```
### 4.2 Asynchronous Client
```python
class DatabaseManagerAsyncClient(AppServiceClient):
"""
REQUIRED: Async client that:
- Directly references async methods
- No conversion needed
- Shares connection pool
"""
@classmethod
def get_service_type(cls):
return DatabaseManager
# Direct method reference
get_graph_execution = DatabaseManager.get_graph_execution
```
## 5. Data Models
### 5.1 Core Enums
```python
class AgentExecutionStatus(str, Enum):
PENDING = "PENDING"
QUEUED = "QUEUED"
RUNNING = "RUNNING"
COMPLETED = "COMPLETED"
FAILED = "FAILED"
CANCELED = "CANCELED"
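class NotificationType(str, Enum):
AGENT_RUN = "AGENT_RUN"
ZERO_BALANCE = "ZERO_BALANCE"
LOW_BALANCE = "LOW_BALANCE"
BLOCK_EXECUTION_FAILED = "BLOCK_EXECUTION_FAILED"
CONTINUOUS_AGENT_ERROR = "CONTINUOUS_AGENT_ERROR"
DAILY_SUMMARY = "DAILY_SUMMARY"
WEEKLY_SUMMARY = "WEEKLY_SUMMARY"
MONTHLY_SUMMARY = "MONTHLY_SUMMARY"
REFUND_REQUEST = "REFUND_REQUEST"
REFUND_PROCESSED = "REFUND_PROCESSED"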
```
### 5.2 Key Data Models
All models must exactly match the Prisma schema definitions. Key models include:
- `GraphExecution`: Execution metadata with stats
- `GraphExecutionWithNodes`: Includes all node executions
- `NodeExecutionResult`: Node execution with I/O data
- `GraphModel`: Complete graph definition
- `UserIntegrations`: OAuth credentials
- `UsageTransactionMetadata`: Credit usage context
- `NotificationEvent`: Individual notification data
## 6. Security Requirements
### 6.1 User Isolation
- **CRITICAL**: All user-scoped operations MUST filter by user_id
- Never expose data across user boundaries
- Use database-level row security where possible
### 6.2 Authentication
- Service assumes authentication handled by API gateway
- user_id parameter is trusted after authentication
- No additional auth checks within service
### 6.3 Data Protection
- Encrypt sensitive integration credentials
- Use HMAC for unsubscribe tokens
- Never log sensitive data
## 7. Performance Requirements
### 7.1 Connection Management
- Maintain persistent database connection
- Use connection pooling (default: 10 connections)
- Implement exponential backoff for retries
### 7.2 Query Optimization
- Use indexes for all WHERE clauses
- Batch operations where possible
- Limit default result sets (50 items)
### 7.3 Event Publishing
- Publish events asynchronously
- Don't block on event delivery
- Use fire-and-forget pattern
## 8. Error Handling
### 8.1 Standard Exceptions
```python
class InsufficientCredits(Exception):
"""Raised when user lacks credits"""
class NotFoundError(Exception):
"""Raised when entity not found"""
class AuthorizationError(Exception):
"""Raised when user lacks access"""
```
### 8.2 Error Response Format
```jsonc
{
"error": "error_type",
"message": "Human readable message",
"details": {} // Optional additional context
}
```
## 9. Testing Requirements
### 9.1 Unit Tests
- Test each method in isolation
- Mock database calls
- Verify user_id filtering
### 9.2 Integration Tests
- Test with real database
- Verify transaction boundaries
- Test concurrent operations
### 9.3 Service Tests
- Test HTTP endpoint generation
- Verify serialization/deserialization
- Test error handling
## 10. Implementation Checklist
### Phase 1: Core Service Setup
- [ ] Create DatabaseManager class inheriting from AppService
- [ ] Implement run_service() with database connection
- [ ] Implement cleanup() with proper disconnect
- [ ] Configure port from settings
- [ ] Set up method exposure helper
### Phase 2: Execution APIs (17 methods)
- [ ] get_graph_execution
- [ ] get_graph_executions
- [ ] get_graph_execution_meta
- [ ] create_graph_execution
- [ ] update_graph_execution_start_time
- [ ] update_graph_execution_stats
- [ ] get_node_execution
- [ ] get_node_executions
- [ ] get_latest_node_execution
- [ ] update_node_execution_status
- [ ] update_node_execution_status_batch
- [ ] update_node_execution_stats
- [ ] upsert_execution_input
- [ ] upsert_execution_output
- [ ] get_execution_kv_data
- [ ] set_execution_kv_data
- [ ] get_block_error_stats
### Phase 3: Graph APIs (4 methods)
- [ ] get_node
- [ ] get_graph
- [ ] get_connected_output_nodes
- [ ] get_graph_metadata
### Phase 4: Credit APIs (2 methods)
- [ ] get_credits
- [ ] spend_credits
### Phase 5: User APIs (4 methods)
- [ ] get_user_metadata
- [ ] update_user_metadata
- [ ] get_user_integrations
- [ ] update_user_integrations
### Phase 6: Communication APIs (4 methods)
- [ ] get_active_user_ids_in_timerange
- [ ] get_user_email_by_id
- [ ] get_user_email_verification
- [ ] get_user_notification_preference
### Phase 7: Notification APIs (5 methods)
- [ ] create_or_add_to_user_notification_batch
- [ ] empty_user_notification_batch
- [ ] get_all_batches_by_type
- [ ] get_user_notification_batch
- [ ] get_user_notification_oldest_message_in_batch
### Phase 8: Client Implementation
- [ ] Create DatabaseManagerClient with sync methods
- [ ] Create DatabaseManagerAsyncClient with async methods
- [ ] Test client method generation
- [ ] Verify type preservation
### Phase 9: Integration Testing
- [ ] Test all methods with real database
- [ ] Verify user isolation
- [ ] Test error scenarios
- [ ] Performance testing
- [ ] Event publishing verification
### Phase 10: Deployment Validation
- [ ] Deploy to test environment
- [ ] Run integration test suite
- [ ] Verify backward compatibility
- [ ] Performance benchmarking
- [ ] Production deployment
## 11. Success Criteria
The implementation is successful when:
1. **All methods listed in the Implementation Checklist (Phases 2–7)** produce identical outputs to the original
2. **Performance** is within 10% of original implementation
3. **All tests** pass without modification
4. **No breaking changes** to any client code
5. **Security boundaries** are maintained
6. **Event publishing** works identically
7. **Error handling** matches original behavior
## 12. Critical Implementation Notes
1. **DO NOT** modify any function signatures
2. **DO NOT** change any return types
3. **DO NOT** add new required parameters
4. **DO NOT** remove any functionality
5. **ALWAYS** maintain user_id isolation
6. **ALWAYS** publish events for state changes
7. **ALWAYS** use transactions for multi-step operations
8. **ALWAYS** handle errors exactly as original
This specification, when implemented correctly, will produce a drop-in replacement for the DatabaseManager that maintains 100% compatibility with the existing system.

View File

@@ -0,0 +1,765 @@
# Notification Service Technical Specification
## Overview
The AutoGPT Platform Notification Service is a RabbitMQ-based asynchronous notification system that handles various types of user notifications including real-time alerts, batched notifications, and scheduled summaries. The service supports email delivery via Postmark and system alerts via Discord.
## Architecture Overview
### Core Components
1. **NotificationManager Service** (`notifications.py`)
- AppService implementation with RabbitMQ integration
- Processes notification queues asynchronously
- Manages batching strategies and delivery timing
- Handles email templating and sending
2. **RabbitMQ Message Broker**
- Multiple queues for different notification strategies
- Dead letter exchange for failed messages
- Topic-based routing for message distribution
3. **Email Sender** (`email.py`)
- Postmark integration for email delivery
- Jinja2 template rendering
- HTML email composition with unsubscribe headers
4. **Database Storage**
- Notification batching tables
- User preference storage
- Email verification tracking
## Service Exposure Mechanism
### AppService Framework
The NotificationManager extends `AppService` which automatically exposes methods decorated with `@expose` as HTTP endpoints:
```python
class NotificationManager(AppService):
@expose
def queue_weekly_summary(self):
# Implementation
@expose
def process_existing_batches(self, notification_types: list[NotificationType]):
# Implementation
@expose
async def discord_system_alert(self, content: str):
# Implementation
```
### Automatic HTTP Endpoint Creation
When the service starts, the AppService base class:
1. Scans for methods with `@expose` decorator
2. Creates FastAPI routes for each exposed method:
- Route path: `/{method_name}`
- HTTP method: POST
- Endpoint handler: Generated via `_create_fastapi_endpoint()`
### Service Client Access
#### NotificationManagerClient
```python
class NotificationManagerClient(AppServiceClient):
@classmethod
def get_service_type(cls):
return NotificationManager
# Direct method references (sync)
process_existing_batches = NotificationManager.process_existing_batches
queue_weekly_summary = NotificationManager.queue_weekly_summary
# Async-to-sync conversion
discord_system_alert = endpoint_to_sync(NotificationManager.discord_system_alert)
```
#### Client Usage Pattern
```python
# Get client instance
client = get_service_client(NotificationManagerClient)
# Call exposed methods via HTTP
client.process_existing_batches([NotificationType.AGENT_RUN])
client.queue_weekly_summary()
client.discord_system_alert("System alert message")
```
### HTTP Communication Details
1. **Service URL**: `http://{host}:{notification_service_port}`
- Default port: 8007
- Host: Configurable via settings
2. **Request Format**:
- Method: POST
- Path: `/{method_name}`
- Body: JSON with method parameters
3. **Client Implementation**:
- Uses `httpx` for HTTP requests
- Automatic retry on connection failures
- Configurable timeout (default from api_call_timeout)
### Direct Function Calls
The notifications module also provides two functions that are called directly (in-process) rather than through the service client:
```python
# Sync version - used by ExecutionManager
def queue_notification(event: NotificationEventModel) -> NotificationResult
# Async version - used by credit system
async def queue_notification_async(event: NotificationEventModel) -> NotificationResult
```
These functions:
- Connect directly to RabbitMQ
- Publish messages to appropriate queues
- Return success/failure status
- Are NOT exposed via HTTP
## Message Queuing Architecture
### RabbitMQ Configuration
#### Exchanges
```python
NOTIFICATION_EXCHANGE = Exchange(name="notifications", type=ExchangeType.TOPIC)
DEAD_LETTER_EXCHANGE = Exchange(name="dead_letter", type=ExchangeType.TOPIC)
```
#### Queues
1. **immediate_notifications**
- Routing Key: `notification.immediate.#`
- Dead Letter: `failed.immediate`
- For: Critical alerts, errors
2. **admin_notifications**
- Routing Key: `notification.admin.#`
- Dead Letter: `failed.admin`
- For: Refund requests, system alerts
3. **summary_notifications**
- Routing Key: `notification.summary.#`
- Dead Letter: `failed.summary`
- For: Daily/weekly summaries
4. **batch_notifications**
- Routing Key: `notification.batch.#`
- Dead Letter: `failed.batch`
- For: Agent runs, batched events
5. **failed_notifications**
- Routing Key: `failed.#`
- For: All failed messages
### Queue Strategies (QueueType enum)
1. **IMMEDIATE**: Send right away (errors, critical notifications)
2. **BATCH**: Batch for configured delay (agent runs)
3. **SUMMARY**: Scheduled digest (daily/weekly summaries)
4. **BACKOFF**: Exponential backoff strategy (defined but not fully implemented)
5. **ADMIN**: Admin-only notifications
## Notification Types
### Enum Values (NotificationType)
```python
AGENT_RUN # Batch strategy, 1 day delay
ZERO_BALANCE # Backoff strategy, 60 min delay
LOW_BALANCE # Immediate strategy
BLOCK_EXECUTION_FAILED # Backoff strategy, 60 min delay
CONTINUOUS_AGENT_ERROR # Backoff strategy, 60 min delay
DAILY_SUMMARY # Summary strategy
WEEKLY_SUMMARY # Summary strategy
MONTHLY_SUMMARY # Summary strategy
REFUND_REQUEST # Admin strategy
REFUND_PROCESSED # Admin strategy
```
## Integration Points
### 1. Scheduler Integration
The scheduler service (`backend.executor.scheduler`) imports monitoring functions that call the NotificationManagerClient:
```python
from backend.monitoring import (
process_existing_batches,
process_weekly_summary,
)
# These are scheduled as cron jobs
```
### 2. Execution Manager Integration
The ExecutionManager directly calls `queue_notification()` for:
- Agent run completions
- Low balance alerts
```python
from backend.notifications.notifications import queue_notification
# Called after graph execution completes
queue_notification(NotificationEventModel(
user_id=graph_exec.user_id,
type=NotificationType.AGENT_RUN,
data=AgentRunData(...)
))
```
### 3. Credit System Integration
The credit system uses `queue_notification_async()` for:
- Refund requests
- Refund processed notifications
```python
from backend.notifications.notifications import queue_notification_async
await queue_notification_async(NotificationEventModel(
user_id=user_id,
type=NotificationType.REFUND_REQUEST,
data=RefundRequestData(...)
))
```
### 4. Monitoring Module Wrappers
The monitoring module provides wrapper functions that are used by the scheduler:
```python
# backend/monitoring/notification_monitor.py
def process_existing_batches(**kwargs):
args = NotificationJobArgs(**kwargs)
get_notification_manager_client().process_existing_batches(
args.notification_types
)
def process_weekly_summary(**kwargs):
get_notification_manager_client().queue_weekly_summary()
```
## Data Models
### Base Event Model
```typescript
interface BaseEventModel {
type: NotificationType;
user_id: string;
created_at: string; // ISO datetime with timezone
}
```
### Notification Event Model
```typescript
interface NotificationEventModel<T> extends BaseEventModel {
data: T;
}
```
### Notification Data Types
#### AgentRunData
```typescript
interface AgentRunData {
agent_name: string;
credits_used: number;
execution_time: number;
node_count: number;
graph_id: string;
outputs: Array<Record<string, any>>;
}
```
#### ZeroBalanceData
```typescript
interface ZeroBalanceData {
last_transaction: number;
last_transaction_time: string; // ISO datetime with timezone
top_up_link: string;
}
```
#### LowBalanceData
```typescript
interface LowBalanceData {
agent_name: string;
current_balance: number; // credits (100 = $1)
billing_page_link: string;
shortfall: number;
}
```
#### BlockExecutionFailedData
```typescript
interface BlockExecutionFailedData {
block_name: string;
block_id: string;
error_message: string;
graph_id: string;
node_id: string;
execution_id: string;
}
```
#### ContinuousAgentErrorData
```typescript
interface ContinuousAgentErrorData {
agent_name: string;
error_message: string;
graph_id: string;
execution_id: string;
start_time: string; // ISO datetime with timezone
error_time: string; // ISO datetime with timezone
attempts: number;
}
```
#### Summary Data Types
```typescript
interface BaseSummaryData {
total_credits_used: number;
total_executions: number;
most_used_agent: string;
total_execution_time: number;
successful_runs: number;
failed_runs: number;
average_execution_time: number;
cost_breakdown: Record<string, number>;
}
interface DailySummaryData extends BaseSummaryData {
date: string; // ISO datetime with timezone
}
interface WeeklySummaryData extends BaseSummaryData {
start_date: string; // ISO datetime with timezone
end_date: string; // ISO datetime with timezone
}
```
#### RefundRequestData
```typescript
interface RefundRequestData {
user_id: string;
user_name: string;
user_email: string;
transaction_id: string;
refund_request_id: string;
reason: string;
amount: number;
balance: number;
}
```
### Summary Parameters
```typescript
interface BaseSummaryParams {
start_date: string; // ISO datetime with timezone
end_date: string; // ISO datetime with timezone
}
interface DailySummaryParams extends BaseSummaryParams {
date: string; // ISO datetime with timezone
}
interface WeeklySummaryParams extends BaseSummaryParams {
start_date: string; // ISO datetime with timezone
end_date: string; // ISO datetime with timezone
}
```
## Database Schema
### NotificationEvent Table
```prisma
model NotificationEvent {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
UserNotificationBatch UserNotificationBatch? @relation
userNotificationBatchId String?
type NotificationType
data Json
@@index([userNotificationBatchId])
}
```
### UserNotificationBatch Table
```prisma
model UserNotificationBatch {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
userId String
User User @relation
type NotificationType
Notifications NotificationEvent[]
@@unique([userId, type])
}
```
## API Methods
### Exposed Service Methods (via HTTP)
#### queue_weekly_summary()
- **HTTP Endpoint**: `POST /queue_weekly_summary`
- **Purpose**: Triggers weekly summary generation for all active users
- **Process**:
1. Runs in background executor
2. Queries users active in last 7 days
3. Queues summary notification for each user
- **Used by**: Scheduler service (via cron)
#### process_existing_batches(notification_types: list[NotificationType])
- **HTTP Endpoint**: `POST /process_existing_batches`
- **Purpose**: Processes aged-out batches for specified notification types
- **Process**:
1. Runs in background executor
2. Retrieves all batches for given types
3. Checks if oldest message exceeds max delay
4. Sends batched email if aged out
5. Clears processed batches
- **Used by**: Scheduler service (via cron)
#### discord_system_alert(content: str)
- **HTTP Endpoint**: `POST /discord_system_alert`
- **Purpose**: Sends system alerts to Discord channel
- **Async**: Yes (converted to sync by client)
- **Used by**: Monitoring services
### Direct Queue Functions (not via HTTP)
#### queue_notification(event: NotificationEventModel) -> NotificationResult
- **Purpose**: Queue a notification (sync version)
- **Used by**: ExecutionManager (same process)
- **Direct RabbitMQ**: Yes
#### queue_notification_async(event: NotificationEventModel) -> NotificationResult
- **Purpose**: Queue a notification (async version)
- **Used by**: Credit system (async context)
- **Direct RabbitMQ**: Yes
## Message Processing Flow
### 1. Message Routing
```python
def get_routing_key(event_type: NotificationType) -> str:
strategy = NotificationTypeOverride(event_type).strategy
if strategy == QueueType.IMMEDIATE:
return f"notification.immediate.{event_type.value}"
elif strategy == QueueType.BATCH:
return f"notification.batch.{event_type.value}"
# ... etc
```
### 2. Queue Processing Methods
#### _process_immediate(message: str) -> bool
1. Parse message to NotificationEventModel
2. Retrieve user email
3. Check user preferences and email verification
4. Send email immediately via EmailSender
5. Return True if successful
#### _process_batch(message: str) -> bool
1. Parse message to NotificationEventModel
2. Add to user's notification batch
3. Check if batch is old enough (based on delay)
4. If aged out:
- Retrieve all batch messages
- Send combined email
- Clear batch
5. Return True if processed or batched
#### _process_summary(message: str) -> bool
1. Parse message to SummaryParamsEventModel
2. Gather summary data (credits, executions, etc.)
- **Note**: Currently returns hardcoded placeholder data
3. Format and send summary email
4. Return True if successful
#### _process_admin_message(message: str) -> bool
1. Parse message
2. Send to configured admin email
3. No user preference checks
4. Return True if successful
## Email Delivery
### EmailSender Class
#### Template Loading
- Base template: `templates/base.html.jinja2`
- Notification templates: `templates/{notification_type}.html.jinja2`
- Subject templates from NotificationTypeOverride
- **Note**: Templates use `.html.jinja2` extension, not just `.html`
#### Email Composition
```python
def send_templated(
notification: NotificationType,
user_email: str,
data: NotificationEventModel | list[NotificationEventModel],
user_unsub_link: str | None = None
)
```
#### Postmark Integration
- API Token: `settings.secrets.postmark_server_api_token`
- Sender Email: `settings.config.postmark_sender_email`
- Headers:
- `List-Unsubscribe-Post: List-Unsubscribe=One-Click`
- `List-Unsubscribe: <{unsubscribe_link}>`
## User Preferences and Permissions
### Email Verification Check
```python
validated_email = get_db().get_user_email_verification(user_id)
```
### Notification Preferences
```python
preferences = get_db().get_user_notification_preference(user_id).preferences
# Returns dict[NotificationType, bool]
```
### Preference Fields in User Model
- `notifyOnAgentRun`
- `notifyOnZeroBalance`
- `notifyOnLowBalance`
- `notifyOnBlockExecutionFailed`
- `notifyOnContinuousAgentError`
- `notifyOnDailySummary`
- `notifyOnWeeklySummary`
- `notifyOnMonthlySummary`
### Unsubscribe Link Generation
```python
def generate_unsubscribe_link(user_id: str) -> str:
# HMAC-SHA256 signed token
# Format: base64(user_id:signature_hex)
# URL: {platform_base_url}/api/email/unsubscribe?token={token}
```
## Batching Logic
### Batch Delays (get_batch_delay)
**Note**: The delay configuration exists for multiple notification types, but only notifications with `QueueType.BATCH` strategy actually use batching. Others use different strategies:
- `AGENT_RUN`: 1 day (Strategy: BATCH - actually uses batching)
- `ZERO_BALANCE`: 60 minutes configured (Strategy: BACKOFF - not batched)
- `LOW_BALANCE`: 60 minutes configured (Strategy: IMMEDIATE - sent immediately)
- `BLOCK_EXECUTION_FAILED`: 60 minutes configured (Strategy: BACKOFF - not batched)
- `CONTINUOUS_AGENT_ERROR`: 60 minutes configured (Strategy: BACKOFF - not batched)
### Batch Processing
1. Messages added to UserNotificationBatch
2. Oldest message timestamp tracked
3. When `oldest_timestamp + delay < now()`:
- Batch is processed
- All messages sent in single email
- Batch cleared
## Service Lifecycle
### Startup
1. Initialize FastAPI app with exposed endpoints
2. Start HTTP server on port 8007
3. Initialize RabbitMQ connection
4. Create/verify exchanges and queues
5. Set up queue consumers
6. Start processing loop
### Main Loop
```python
while self.running:
await self._run_queue(immediate_queue, self._process_immediate, ...)
await self._run_queue(admin_queue, self._process_admin_message, ...)
await self._run_queue(batch_queue, self._process_batch, ...)
await self._run_queue(summary_queue, self._process_summary, ...)
await asyncio.sleep(0.1)
```
### Shutdown
1. Set `running = False`
2. Disconnect RabbitMQ
3. Cleanup resources
## Configuration
### Environment Variables
```python
# Service Configuration
notification_service_port: int = 8007
# Email Configuration
postmark_sender_email: str = "invalid@invalid.com"
refund_notification_email: str = "refund@agpt.co"
# Security
unsubscribe_secret_key: str = ""
# Secrets
postmark_server_api_token: str = ""
postmark_webhook_token: str = ""
discord_bot_token: str = ""
# Platform URLs
platform_base_url: str
frontend_base_url: str
```
## Error Handling
### Message Processing Errors
- Failed messages sent to dead letter queue
- Validation errors logged but don't crash service
- Connection errors trigger retry with `@continuous_retry()`
### RabbitMQ ACK/NACK Protocol
- Success: `message.ack()`
- Failure: `message.reject(requeue=False)`
- Timeout/Queue empty: Continue loop
### HTTP Endpoint Errors
- Wrapped in RemoteCallError for client
- Automatic retry available via client configuration
- Connection failures tracked and logged
## System Integrations
### DatabaseManagerClient
- User email retrieval
- Email verification status
- Notification preferences
- Batch management
- Active user queries
### Discord Integration
- Uses SendDiscordMessageBlock
- Configured via discord_bot_token
- For system alerts only
## Implementation Checklist
1. **Core Service**
- [ ] AppService implementation with @expose decorators
- [ ] FastAPI endpoint generation
- [ ] RabbitMQ connection management
- [ ] Queue consumer setup
- [ ] Message routing logic
2. **Service Client**
- [ ] NotificationManagerClient implementation
- [ ] HTTP client configuration
- [ ] Method mapping to service endpoints
- [ ] Async-to-sync conversions
3. **Message Processing**
- [ ] Parse and validate all notification types
- [ ] Implement all queue strategies
- [ ] Batch management with delays
- [ ] Summary data gathering
4. **Email Delivery**
- [ ] Postmark integration
- [ ] Template loading and rendering
- [ ] Unsubscribe header support
- [ ] HTML email composition
5. **User Management**
- [ ] Preference checking
- [ ] Email verification
- [ ] Unsubscribe link generation
- [ ] Daily limit tracking
6. **Batching System**
- [ ] Database batch operations
- [ ] Age-out checking
- [ ] Batch clearing after send
- [ ] Oldest message tracking
7. **Error Handling**
- [ ] Dead letter queue routing
- [ ] Message rejection on failure
- [ ] Continuous retry wrapper
- [ ] Validation error logging
8. **Scheduled Operations**
- [ ] Weekly summary generation
- [ ] Batch processing triggers
- [ ] Background executor usage
## Security Considerations
1. **Service-to-Service Communication**:
- HTTP endpoints only accessible internally
- No authentication on service endpoints (internal network only)
- Service discovery via host/port configuration
2. **User Security**:
- Email verification required for all user notifications
- Unsubscribe tokens HMAC-signed
- User preferences enforced
3. **Admin Notifications**:
- Separate queue, no user preference checks
- Fixed admin email configuration
## Testing Considerations
1. **Unit Tests**
- Message parsing and validation
- Routing key generation
- Batch delay calculations
- Template rendering
2. **Integration Tests**
- HTTP endpoint accessibility
- Service client method calls
- RabbitMQ message flow
- Database batch operations
- Email sending (mock Postmark)
3. **Load Tests**
- High volume message processing
- Concurrent HTTP requests
- Batch accumulation limits
- Memory usage under load
## Implementation Status Notes
1. **Backoff Strategy**: While `QueueType.BACKOFF` is defined and used by several notification types (ZERO_BALANCE, BLOCK_EXECUTION_FAILED, CONTINUOUS_AGENT_ERROR), the actual exponential backoff processing logic is not implemented. These messages are currently routed to the immediate queue.
2. **Summary Data**: The `_gather_summary_data()` method currently returns hardcoded placeholder values rather than querying actual execution data from the database.
3. **Batch Processing**: Only `AGENT_RUN` notifications actually use batch processing. Other notification types with configured delays use different strategies (IMMEDIATE or BACKOFF).
## Future Enhancements
1. **Additional Channels**
- SMS notifications (not implemented)
- Webhook notifications (not implemented)
- In-app notifications
2. **Advanced Batching**
- Dynamic batch sizes
- Priority-based processing
- Custom delay configurations
3. **Analytics**
- Delivery tracking
- Open/click rates
- Notification effectiveness metrics
4. **Service Improvements**
- Authentication for HTTP endpoints
- Rate limiting per user
- Circuit breaker patterns
- Implement actual backoff processing for BACKOFF strategy
- Implement real summary data gathering

View File

@@ -0,0 +1,474 @@
# AutoGPT Platform Scheduler Technical Specification
## Executive Summary
This document provides a comprehensive technical specification for the AutoGPT Platform Scheduler service. The scheduler is responsible for managing scheduled graph executions, system monitoring tasks, and periodic maintenance operations. This specification is designed to enable a complete reimplementation that maintains 100% compatibility with the existing system.
## Table of Contents
1. [System Architecture](#system-architecture)
2. [Service Implementation](#service-implementation)
3. [Data Models](#data-models)
4. [API Endpoints](#api-endpoints)
5. [Database Schema](#database-schema)
6. [External Dependencies](#external-dependencies)
7. [Authentication & Authorization](#authentication--authorization)
8. [Process Management](#process-management)
9. [Error Handling](#error-handling)
10. [Configuration](#configuration)
11. [Testing Strategy](#testing-strategy)
## System Architecture
### Overview
The scheduler operates as an independent microservice within the AutoGPT platform, implementing the `AppService` base class pattern. It runs on a dedicated port (default: 8003) and exposes HTTP/JSON-RPC endpoints for communication with other services.
### Core Components
1. **Scheduler Service** (`backend/executor/scheduler.py:156`)
- Extends `AppService` base class
- Manages APScheduler instance with multiple jobstores
- Handles lifecycle management and graceful shutdown
2. **Scheduler Client** (`backend/executor/scheduler.py:354`)
- Extends `AppServiceClient` base class
- Provides async/sync method wrappers for RPC calls
- Implements automatic retry and connection pooling
3. **Entry Points**
- Main executable: `backend/scheduler.py`
- Service launcher: `backend/app.py`
## Service Implementation
### Base Service Pattern
```python
class Scheduler(AppService):
scheduler: BlockingScheduler
def __init__(self, register_system_tasks: bool = True):
self.register_system_tasks = register_system_tasks
@classmethod
def get_port(cls) -> int:
return config.execution_scheduler_port # Default: 8003
@classmethod
def db_pool_size(cls) -> int:
return config.scheduler_db_pool_size # Default: 3
def run_service(self):
# Initialize scheduler with jobstores
# Register system tasks if enabled
# Start scheduler blocking loop
def cleanup(self):
# Graceful shutdown of scheduler
# Wait=False for immediate termination
```
### Jobstore Configuration
The scheduler uses three distinct jobstores:
1. **EXECUTION** (`Jobstores.EXECUTION.value`)
- Type: SQLAlchemyJobStore
- Table: `apscheduler_jobs`
- Purpose: Graph execution schedules
- Persistence: Required
2. **BATCHED_NOTIFICATIONS** (`Jobstores.BATCHED_NOTIFICATIONS.value`)
- Type: SQLAlchemyJobStore
- Table: `apscheduler_jobs_batched_notifications`
- Purpose: Batched notification processing
- Persistence: Required
3. **WEEKLY_NOTIFICATIONS** (`Jobstores.WEEKLY_NOTIFICATIONS.value`)
- Type: MemoryJobStore
- Purpose: Weekly summary notifications
- Persistence: Not required
### System Tasks
When `register_system_tasks=True`, the following monitoring tasks are registered:
1. **Weekly Summary Processing**
- Job ID: `process_weekly_summary`
- Schedule: `0 * * * *` (hourly)
- Function: `monitoring.process_weekly_summary`
- Jobstore: WEEKLY_NOTIFICATIONS
2. **Late Execution Monitoring**
- Job ID: `report_late_executions`
- Schedule: Interval (config.execution_late_notification_threshold_secs)
- Function: `monitoring.report_late_executions`
- Jobstore: EXECUTION
3. **Block Error Rate Monitoring**
- Job ID: `report_block_error_rates`
- Schedule: Interval (config.block_error_rate_check_interval_secs)
- Function: `monitoring.report_block_error_rates`
- Jobstore: EXECUTION
4. **Cloud Storage Cleanup**
- Job ID: `cleanup_expired_files`
- Schedule: Interval in seconds (config.cloud_storage_cleanup_interval_hours * 3600)
- Function: `cleanup_expired_files`
- Jobstore: EXECUTION
## Data Models
### GraphExecutionJobArgs
```python
class GraphExecutionJobArgs(BaseModel):
user_id: str
graph_id: str
graph_version: int
cron: str
input_data: BlockInput
input_credentials: dict[str, CredentialsMetaInput] = Field(default_factory=dict)
```
### GraphExecutionJobInfo
```python
class GraphExecutionJobInfo(GraphExecutionJobArgs):
id: str
name: str
next_run_time: str
@staticmethod
def from_db(job_args: GraphExecutionJobArgs, job_obj: JobObj) -> "GraphExecutionJobInfo":
return GraphExecutionJobInfo(
id=job_obj.id,
name=job_obj.name,
next_run_time=job_obj.next_run_time.isoformat(),
**job_args.model_dump(),
)
```
### NotificationJobArgs
```python
class NotificationJobArgs(BaseModel):
notification_types: list[NotificationType]
cron: str
```
### CredentialsMetaInput
```python
class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
id: str
title: Optional[str] = None
provider: CP
type: CT
```
## API Endpoints
All endpoints are exposed via the `@expose` decorator and follow HTTP POST JSON-RPC pattern.
### 1. Add Graph Execution Schedule
**Endpoint**: `/add_graph_execution_schedule`
**Request Body**:
```json
{
"user_id": "string",
"graph_id": "string",
"graph_version": "integer",
"cron": "string (crontab format)",
"input_data": {},
"input_credentials": {},
"name": "string (optional)"
}
```
**Response**: `GraphExecutionJobInfo`
**Behavior**:
- Creates APScheduler job with CronTrigger
- Uses job kwargs to store GraphExecutionJobArgs
- Sets `replace_existing=True` to allow updates
- Returns job info with generated ID and next run time
### 2. Delete Graph Execution Schedule
**Endpoint**: `/delete_graph_execution_schedule`
**Request Body**:
```json
{
"schedule_id": "string",
"user_id": "string"
}
```
**Response**: `GraphExecutionJobInfo`
**Behavior**:
- Validates schedule exists in EXECUTION jobstore
- Verifies user_id matches job's user_id
- Removes job from scheduler
- Returns deleted job info
**Errors**:
- `NotFoundError`: If job doesn't exist
- `NotAuthorizedError`: If user_id doesn't match
### 3. Get Graph Execution Schedules
**Endpoint**: `/get_graph_execution_schedules`
**Request Body**:
```json
{
"graph_id": "string (optional)",
"user_id": "string (optional)"
}
```
**Response**: `list[GraphExecutionJobInfo]`
**Behavior**:
- Retrieves all jobs from EXECUTION jobstore
- Filters by graph_id and/or user_id if provided
- Validates job kwargs as GraphExecutionJobArgs
- Skips invalid jobs (ValidationError)
- Only returns jobs with next_run_time set
### 4. System Task Endpoints
- `/execute_process_existing_batches` - Trigger batch processing
- `/execute_process_weekly_summary` - Trigger weekly summary
- `/execute_report_late_executions` - Trigger late execution report
- `/execute_report_block_error_rates` - Trigger error rate report
- `/execute_cleanup_expired_files` - Trigger file cleanup
### 5. Health Check
**Endpoints**: `/health_check`, `/health_check_async`
**Methods**: POST, GET
**Response**: "OK"
## Database Schema
### APScheduler Tables
The scheduler relies on APScheduler's SQLAlchemy jobstore schema:
1. **apscheduler_jobs**
- id: VARCHAR (PRIMARY KEY)
- next_run_time: FLOAT
- job_state: BLOB/BYTEA (pickled job data)
2. **apscheduler_jobs_batched_notifications**
- Same schema as above
- Separate table for notification jobs
### Database Configuration
- URL extraction from `DIRECT_URL` environment variable
- Schema extraction from URL query parameter
- Connection pooling: `pool_size=db_pool_size()`, `max_overflow=0`
- Metadata schema binding for multi-schema support
## External Dependencies
### Required Services
1. **PostgreSQL Database**
- Connection via `DIRECT_URL` environment variable
- Schema support via URL parameter
- APScheduler job persistence
2. **ExecutionManager** (via execution_utils)
- Function: `add_graph_execution`
- Called by: `execute_graph` job function
- Purpose: Create graph execution entries
3. **NotificationManager** (via monitoring module)
- Functions: `process_existing_batches`, `queue_weekly_summary`
- Purpose: Notification processing
4. **Cloud Storage** (via util.cloud_storage)
- Function: `cleanup_expired_files_async`
- Purpose: File expiration management
### Python Dependencies
```
apscheduler>=3.10.0
sqlalchemy
pydantic>=2.0
httpx
uvicorn
fastapi
python-dotenv
tenacity
```
## Authentication & Authorization
### Service-Level Authentication
- No authentication required between internal services
- Services communicate via trusted internal network
- Host/port configuration via environment variables
### User-Level Authorization
- Authorization check in `delete_graph_execution_schedule`:
- Validates `user_id` matches job's `user_id`
- Raises `NotAuthorizedError` on mismatch
- No authorization for read operations (security consideration)
## Process Management
### Startup Sequence
1. Load environment variables via `dotenv.load_dotenv()`
2. Extract database URL and schema
3. Initialize BlockingScheduler with configured jobstores
4. Register system tasks (if enabled)
5. Add job execution listener
6. Start scheduler (blocking)
### Shutdown Sequence
1. Receive SIGTERM/SIGINT signal
2. Call `cleanup()` method
3. Shutdown scheduler with `wait=False`
4. Terminate process
### Multi-Process Architecture
- Runs as independent process via `AppProcess`
- Started by `run_processes()` in app.py
- Can run in foreground or background mode
- Automatic signal handling for graceful shutdown
## Error Handling
### Job Execution Errors
- Listener on `EVENT_JOB_ERROR` logs failures
- Errors in job functions are caught and logged
- Jobs continue to run on schedule despite failures
### RPC Communication Errors
- Automatic retry via `@conn_retry` decorator
- Configurable retry count and timeout
- Connection pooling with self-healing
### Database Connection Errors
- APScheduler handles reconnection automatically
- Connection count capped at `pool_size` by `max_overflow=0` (callers wait for a free connection instead of opening extra ones)
- Connection errors logged but don't crash service
## Configuration
### Environment Variables
- `DIRECT_URL`: PostgreSQL connection string (required)
- `{SERVICE_NAME}_HOST`: Override service host
- Standard logging configuration
### Config Settings (via Config class)
```python
execution_scheduler_port: int = 8003
scheduler_db_pool_size: int = 3
execution_late_notification_threshold_secs: int
block_error_rate_check_interval_secs: int
cloud_storage_cleanup_interval_hours: int
pyro_host: str = "localhost"
pyro_client_comm_timeout: float = 15
pyro_client_comm_retry: int = 3
rpc_client_call_timeout: int = 300
```
## Testing Strategy
### Unit Tests
1. Mock APScheduler for job management tests
2. Mock database connections
3. Test each RPC endpoint independently
4. Verify job serialization/deserialization
### Integration Tests
1. Test with real PostgreSQL instance
2. Verify job persistence across restarts
3. Test concurrent job execution
4. Validate cron expression parsing
### Critical Test Cases
1. **Job Persistence**: Jobs survive scheduler restart
2. **User Isolation**: Users can only delete their own jobs
3. **Concurrent Access**: Multiple clients can add/remove jobs
4. **Error Recovery**: Service recovers from database outages
5. **Resource Cleanup**: No memory/connection leaks
## Implementation Notes
### Key Design Decisions
1. **BlockingScheduler vs AsyncIOScheduler**: Uses BlockingScheduler for simplicity and compatibility with multiprocessing architecture
2. **Job Storage**: All job arguments stored in kwargs, not in job name/id
3. **Separate Jobstores**: Isolation between execution and notification jobs
4. **No Authentication**: Relies on network isolation for security
### Migration Considerations
1. APScheduler job format must be preserved exactly
2. Database schema cannot change without migration
3. RPC protocol must maintain compatibility
4. Environment variables must match existing deployment
### Performance Considerations
1. Database pool size limited to prevent exhaustion
2. No job result storage (fire-and-forget pattern)
3. Minimal logging in hot paths
4. Connection reuse via pooling
## Appendix: Critical Implementation Details
### Event Loop Management
```python
@thread_cached
def get_event_loop():
return asyncio.new_event_loop()
def execute_graph(**kwargs):
get_event_loop().run_until_complete(_execute_graph(**kwargs))
```
### Job Function Execution Context
- Jobs run in scheduler's process space
- Jobs reuse one event loop per scheduler thread (`get_event_loop` is `@thread_cached`); a new loop is not created for every job
- No shared state between job executions
- Exceptions logged but don't affect scheduler
### Cron Expression Format
- Uses standard crontab format via `CronTrigger.from_crontab()`
- Supports: minute hour day month day_of_week
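- Special strings (`@yearly`, `@monthly`, `@weekly`, `@daily`, `@hourly`) are NOT accepted: `CronTrigger.from_crontab()` requires exactly five whitespace-separated fields and raises `ValueError` otherwise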
This specification provides all necessary details to reimplement the scheduler service while maintaining 100% compatibility with the existing system. Any deviation from these specifications may result in system incompatibility.

View File

@@ -0,0 +1,85 @@
# CI for the Rust WebSocket server: tests, lints, formatting and a
# benchmark build check.
name: CI

on:
  push:
    branches: [ main, master ]
  pull_request:
    branches: [ main, master ]

env:
  CARGO_TERM_COLOR: always
  # Treat every compiler warning as an error in all jobs.
  RUSTFLAGS: "-D warnings"

jobs:
  test:
    name: Test
    runs-on: ubuntu-latest
    services:
      # Redis service container; REDIS_URL below points the tests at it.
      redis:
        image: redis:7
        options: >-
          --health-cmd "redis-cli ping"
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
        ports:
          - 6379:6379
    steps:
      # checkout@v4 matches the version used by the repository's other workflows.
      - uses: actions/checkout@v4
      - uses: dtolnay/rust-toolchain@stable
      - uses: Swatinem/rust-cache@v2
      - name: Run tests
        run: cargo test
        env:
          REDIS_URL: redis://localhost:6379

  clippy:
    name: Clippy
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: dtolnay/rust-toolchain@stable
        with:
          components: clippy
      - uses: Swatinem/rust-cache@v2
      - name: Run clippy
        # Deny panicking constructs in addition to ordinary warnings.
        run: |
          cargo clippy -- \
            -D warnings \
            -D clippy::unwrap_used \
            -D clippy::panic \
            -D clippy::unimplemented \
            -D clippy::todo

  fmt:
    name: Format
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: dtolnay/rust-toolchain@stable
        with:
          components: rustfmt
      - name: Check formatting
        run: cargo fmt -- --check

  bench:
    name: Benchmarks
    runs-on: ubuntu-latest
    services:
      redis:
        image: redis:7
        options: >-
          --health-cmd "redis-cli ping"
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
        ports:
          - 6379:6379
    steps:
      - uses: actions/checkout@v4
      - uses: dtolnay/rust-toolchain@stable
      - uses: Swatinem/rust-cache@v2
      # Only compiles the benchmarks (--no-run); they are not executed in CI.
      - name: Build benchmarks
        run: cargo bench --no-run
        env:
          REDIS_URL: redis://localhost:6379

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,60 @@
[package]
name = "websocket"
authors = ["AutoGPT Team"]
description = "WebSocket server for AutoGPT Platform"
version = "0.1.0"
edition = "2021"
[lib]
name = "websocket"
path = "src/lib.rs"
[[bin]]
name = "websocket"
path = "src/main.rs"
[dependencies]
axum = { version = "0.7.5", features = ["ws"] }
jsonwebtoken = "9.3.0"
redis = { version = "0.25.4", features = ["aio", "tokio-comp"] }
serde = { version = "1.0.204", features = ["derive"] }
serde_json = "1.0.120"
tokio = { version = "1.38.1", features = ["rt-multi-thread", "macros", "net", "sync", "time", "io-util"] }
tower-http = { version = "0.5.2", features = ["cors"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
futures = "0.3"
dotenvy = "0.15"
clap = { version = "4.5.4", features = ["derive"] }
toml = "0.8"
[dev-dependencies]
# Load testing and profiling
tokio-console = "0.1"
criterion = { version = "0.5", features = ["async_tokio"] }
pprof = { version = "0.13", features = ["flamegraph", "criterion"] }
# Dependencies for benchmarks
tokio-tungstenite = "0.24"
futures-util = "0.3"
chrono = "0.4"
[[bench]]
name = "websocket_bench"
harness = false
[[example]]
name = "ws_client_example"
required-features = []
[profile.release]
opt-level = 3 # Maximum optimization
lto = true # Enable link-time optimization
codegen-units = 1 # Reduce parallel code generation units to increase optimization
panic = "abort" # Remove panic unwinding to reduce binary size
strip = true # Strip symbols from binary
[profile.bench]
opt-level = 3 # Maximum optimization
lto = true # Enable link-time optimization
codegen-units = 1 # Reduce parallel code generation units to increase optimization
debug = true # Keep debug symbols for profiling

View File

@@ -0,0 +1,412 @@
# WebSocket API Technical Specification
## Overview
This document provides a complete technical specification for the AutoGPT Platform WebSocket API (`ws_api.py`). The WebSocket API provides real-time updates for graph and node execution events, enabling clients to monitor workflow execution progress.
## Architecture Overview
### Core Components
1. **WebSocket Server** (`ws_api.py`)
- FastAPI application with WebSocket endpoint
- Handles client connections and message routing
- Authenticates clients via JWT tokens
- Manages subscriptions to execution events
2. **Connection Manager** (`conn_manager.py`)
- Maintains active WebSocket connections
- Manages channel subscriptions
- Routes execution events to subscribed clients
- Handles connection lifecycle
3. **Event Broadcasting System**
- Redis Pub/Sub based event bus
- Asynchronous event broadcaster
- Execution event propagation from backend services
## API Endpoint
### WebSocket Endpoint
- **URL**: `/ws`
- **Protocol**: WebSocket (ws:// or wss://)
- **Query Parameters**:
- `token` (required when auth enabled): JWT authentication token
## Authentication
### JWT Token Authentication
- **When Required**: When `settings.config.enable_auth` is `True`
- **Token Location**: Query parameter `?token=<JWT_TOKEN>`
- **Token Validation**:
```python
payload = parse_jwt_token(token)
user_id = payload.get("sub")
```
- **JWT Requirements**:
- Algorithm: Configured via `settings.JWT_ALGORITHM`
- Secret Key: Configured via `settings.JWT_SECRET_KEY`
- Audience: Must be "authenticated"
- Claims: Must contain `sub` (user ID)
### Authentication Failures
- **4001**: Missing authentication token
- **4002**: Invalid token (missing user ID)
- **4003**: Invalid token (parsing error or expired)
### No-Auth Mode
- When `settings.config.enable_auth` is `False`
- Uses `DEFAULT_USER_ID` from `backend.data.user`
## Message Protocol
### Message Format
All messages use JSON format with the following structure:
```typescript
interface WSMessage {
method: WSMethod;
data?: Record<string, any> | any[] | string;
success?: boolean;
channel?: string;
error?: string;
}
```
### Message Methods (WSMethod enum)
1. **Client-to-Server Methods**:
- `SUBSCRIBE_GRAPH_EXEC`: Subscribe to specific graph execution
- `SUBSCRIBE_GRAPH_EXECS`: Subscribe to all executions of a graph
- `UNSUBSCRIBE`: Unsubscribe from a channel
- `HEARTBEAT`: Keep-alive ping
2. **Server-to-Client Methods**:
- `GRAPH_EXECUTION_EVENT`: Graph execution status update
- `NODE_EXECUTION_EVENT`: Node execution status update
- `ERROR`: Error message
- `HEARTBEAT`: Keep-alive pong
## Subscription Models
### Subscribe to Specific Graph Execution
```typescript
interface WSSubscribeGraphExecutionRequest {
graph_exec_id: string;
}
```
**Channel Key Format**: `{user_id}|graph_exec#{graph_exec_id}`
### Subscribe to All Graph Executions
```typescript
interface WSSubscribeGraphExecutionsRequest {
graph_id: string;
}
```
**Channel Key Format**: `{user_id}|graph#{graph_id}|executions`
## Event Models
### Graph Execution Event
```typescript
interface GraphExecutionEvent {
event_type: "graph_execution_update";
id: string; // graph_exec_id
user_id: string;
graph_id: string;
graph_version: number;
preset_id?: string;
status: ExecutionStatus;
started_at: string; // ISO datetime
ended_at: string; // ISO datetime
inputs: Record<string, any>;
outputs: Record<string, any>;
stats?: {
cost: number; // cents
duration: number; // seconds
duration_cpu_only: number;
node_exec_time: number;
node_exec_time_cpu_only: number;
node_exec_count: number;
node_error_count: number;
error?: string;
};
}
```
### Node Execution Event
```typescript
interface NodeExecutionEvent {
event_type: "node_execution_update";
user_id: string;
graph_id: string;
graph_version: number;
graph_exec_id: string;
node_exec_id: string;
node_id: string;
block_id: string;
status: ExecutionStatus;
input_data: Record<string, any>;
output_data: Record<string, any>;
add_time: string; // ISO datetime
queue_time?: string; // ISO datetime
start_time?: string; // ISO datetime
end_time?: string; // ISO datetime
}
```
### Execution Status Enum
```typescript
enum ExecutionStatus {
INCOMPLETE = "INCOMPLETE",
QUEUED = "QUEUED",
RUNNING = "RUNNING",
COMPLETED = "COMPLETED",
FAILED = "FAILED"
}
```
## Message Flow Examples
### 1. Subscribe to Graph Execution
```json
// Client → Server
{
"method": "subscribe_graph_execution",
"data": {
"graph_exec_id": "exec-123"
}
}
// Server → Client (Success)
{
"method": "subscribe_graph_execution",
"success": true,
"channel": "user-456|graph_exec#exec-123"
}
```
### 2. Receive Execution Updates
```json
// Server → Client (Graph Update)
{
"method": "graph_execution_event",
"channel": "user-456|graph_exec#exec-123",
"data": {
"event_type": "graph_execution_update",
"id": "exec-123",
"user_id": "user-456",
"graph_id": "graph-789",
"status": "RUNNING",
// ... other fields
}
}
// Server → Client (Node Update)
{
"method": "node_execution_event",
"channel": "user-456|graph_exec#exec-123",
"data": {
"event_type": "node_execution_update",
"node_exec_id": "node-exec-111",
"status": "COMPLETED",
// ... other fields
}
}
```
### 3. Heartbeat
```json
// Client → Server
{
"method": "heartbeat",
"data": "ping"
}
// Server → Client
{
"method": "heartbeat",
"data": "pong",
"success": true
}
```
### 4. Error Handling
```json
// Server → Client (Invalid Message)
{
"method": "error",
"success": false,
"error": "Invalid message format. Review the schema and retry"
}
```
## Event Broadcasting Architecture
### Redis Pub/Sub Integration
1. **Event Bus Name**: Configured via `config.execution_event_bus_name`
2. **Channel Pattern**: `{event_bus_name}/{channel_key}`
3. **Event Flow**:
- Execution services publish events to Redis
- Event broadcaster listens to Redis pattern `*`
- Events are routed to WebSocket connections based on subscriptions
### Event Broadcaster
- Runs as continuous async task using `@continuous_retry()` decorator
- Listens to all execution events via `AsyncRedisExecutionEventBus`
- Calls `ConnectionManager.send_execution_update()` for each event
## Connection Lifecycle
### Connection Establishment
1. Client connects to `/ws` endpoint
2. Authentication performed (JWT validation)
3. WebSocket accepted via `manager.connect_socket()`
4. Connection added to active connections set
### Message Processing Loop
1. Receive text message from client
2. Parse and validate as `WSMessage`
3. Route to appropriate handler based on `method`
4. Send response or error back to client
### Connection Termination
1. `WebSocketDisconnect` exception caught
2. `manager.disconnect_socket()` called
3. Connection removed from active connections
4. All subscriptions for that connection removed
## Error Handling
### Validation Errors
- **Invalid Message Format**: Returns error with method "error"
- **Invalid Message Data**: Returns error with specific validation message
- **Unknown Message Type**: Returns error indicating unsupported method
### Connection Errors
- WebSocket disconnections handled gracefully
- Failed event parsing logged but doesn't crash connection
- Handler exceptions logged and connection continues
## Configuration
### Environment Variables
```python
# WebSocket Server Configuration
websocket_server_host: str = "0.0.0.0"
websocket_server_port: int = 8001
# Authentication
enable_auth: bool = True
# CORS
backend_cors_allow_origins: List[str] = []
# Redis Event Bus
execution_event_bus_name: str = "autogpt:execution_event_bus"
# Message Size Limits
max_message_size_limit: int = 512000 # 512KB
```
### Security Headers
- CORS middleware applied with configured origins
- Credentials allowed for authenticated requests
- All methods and headers allowed (configurable)
## Deployment Requirements
### Dependencies
1. **FastAPI**: Web framework with WebSocket support
2. **Redis**: For pub/sub event broadcasting
3. **JWT Libraries**: For token validation
4. **Prisma**: Database ORM (for future graph access validation)
### Process Management
- Implements `AppProcess` interface for service lifecycle
- Runs via `uvicorn` ASGI server
- Graceful shutdown handling in `cleanup()` method
### Concurrent Connections
- No hard limit on WebSocket connections
- Memory usage scales with active connections
- Each connection maintains subscription set
## Implementation Checklist
To implement a compatible WebSocket API:
1. **Authentication**
- [ ] JWT token validation from query parameters
- [ ] Support for no-auth mode with default user ID
- [ ] Proper error codes for auth failures
2. **Message Handling**
- [ ] Parse and validate WSMessage format
- [ ] Implement all client-to-server methods
- [ ] Support all server-to-client event types
- [ ] Proper error responses for invalid messages
3. **Subscription Management**
- [ ] Channel key generation matching exact format
- [ ] Support for both execution and graph-level subscriptions
- [ ] Unsubscribe functionality
- [ ] Clean up subscriptions on disconnect
4. **Event Broadcasting**
- [ ] Listen to Redis pub/sub for execution events
- [ ] Route events to correct subscribed connections
- [ ] Handle both graph and node execution events
- [ ] Maintain event order and completeness
5. **Connection Management**
- [ ] Track active WebSocket connections
- [ ] Handle graceful disconnections
- [ ] Implement heartbeat/keepalive
- [ ] Memory-efficient subscription storage
6. **Configuration**
- [ ] Support all environment variables
- [ ] CORS configuration for allowed origins
- [ ] Configurable host/port binding
- [ ] Redis connection configuration
7. **Error Handling**
- [ ] Graceful handling of malformed messages
- [ ] Logging of errors without dropping connections
- [ ] Specific error messages for debugging
- [ ] Recovery from Redis connection issues
## Testing Considerations
1. **Unit Tests**
- Message parsing and validation
- Channel key generation
- Subscription management logic
2. **Integration Tests**
- Full WebSocket connection flow
- Event broadcasting from Redis
- Multi-client subscription scenarios
- Authentication success/failure cases
3. **Load Tests**
- Many concurrent connections
- High-frequency event broadcasting
- Memory usage under load
- Connection/disconnection cycles
## Security Considerations
1. **Authentication**: JWT tokens transmitted via query parameters (consider upgrading to headers)
2. **Authorization**: Currently no graph-level access validation (commented out in code)
3. **Rate Limiting**: No rate limiting implemented
4. **Message Size**: Limited by `max_message_size_limit` configuration
5. **Input Validation**: All inputs validated via Pydantic models
## Future Enhancements (Currently Commented Out)
1. **Graph Access Validation**: Verify user has read access to subscribed graphs
2. **Message Compression**: For large execution payloads
3. **Batch Updates**: Aggregate multiple events in single message
4. **Selective Field Subscription**: Subscribe to specific fields only

View File

@@ -0,0 +1,93 @@
# WebSocket Server Benchmarks
This directory contains performance benchmarks for the AutoGPT WebSocket server.
## Prerequisites
1. Redis must be running locally or set `REDIS_URL` environment variable:
```bash
docker run -d -p 6379:6379 redis:latest
```
2. Build the project in release mode:
```bash
cargo build --release
```
## Running Benchmarks
Run all benchmarks:
```bash
cargo bench
```
Run specific benchmark group:
```bash
cargo bench connection_establishment
cargo bench subscriptions
cargo bench message_throughput
cargo bench concurrent_connections
cargo bench message_parsing
cargo bench redis_event_processing
```
## Benchmark Categories
### Connection Establishment
Tests the performance of establishing WebSocket connections with different authentication scenarios:
- No authentication
- Valid JWT authentication
- Invalid JWT authentication (connection rejection)
### Subscriptions
Measures the performance of subscription operations:
- Subscribing to graph execution events
- Unsubscribing from channels
### Message Throughput
Tests how many messages the server can process per second with varying message counts (10, 100, 1000).
### Concurrent Connections
Benchmarks the server's ability to handle multiple simultaneous connections (100, 500, 1000 clients).
### Message Parsing
Tests JSON parsing performance with different message sizes (100B to 10KB).
### Redis Event Processing
Benchmarks the parsing of execution events received from Redis.
## Profiling
To generate flamegraphs for CPU profiling:
1. Install flamegraph tools:
```bash
cargo install flamegraph
```
2. Run benchmarks with profiling:
```bash
cargo bench --bench websocket_bench -- --profile-time=10
```
## Interpreting Results
- **Throughput**: Higher is better (operations/second or elements/second)
- **Time**: Lower is better (nanoseconds per operation)
- **Error margins**: Look for stable results with low standard deviation
## Optimizing Performance
Based on benchmark results, consider:
1. **Connection pooling** for Redis connections
2. **Message batching** for high-throughput scenarios
3. **Async task tuning** for concurrent connection handling
4. **JSON parsing optimization** using simd-json or other fast parsers
5. **Memory allocation** optimization using arena allocators
## Notes
- Benchmarks create actual WebSocket servers on random ports
- Each benchmark iteration properly cleans up resources
- Results may vary based on system resources and Redis performance

View File

@@ -0,0 +1,406 @@
#![allow(clippy::unwrap_used)] // Benchmarks can panic on setup errors
use axum::{routing::get, Router};
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use futures_util::{SinkExt, StreamExt};
use serde_json::json;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::runtime::Runtime;
use tokio_tungstenite::{connect_async, tungstenite::Message};
// Import the actual websocket server components
use websocket::{models, ws_handler, AppState, Config, ConnectionManager, Stats};
/// Starts an in-process WebSocket server for a benchmark iteration.
///
/// Returns the server's `ws://host:port` base URL (the OS picks the port) and
/// the `JoinHandle` of the task that runs it.
///
/// The Redis broadcaster and the HTTP/WebSocket server are driven by that
/// single task, so calling `abort()` on the returned handle stops both.
/// Spawning the broadcaster as its own detached task would leave it (and its
/// Redis subscription) running after every iteration.
///
/// Panics if the Redis client cannot be created, the listener cannot be
/// bound, or the server exits with an error.
async fn create_test_server(enable_auth: bool) -> (String, tokio::task::JoinHandle<()>) {
    // Set environment variables for test config before it is loaded.
    std::env::set_var("WEBSOCKET_SERVER_HOST", "127.0.0.1");
    std::env::set_var("WEBSOCKET_SERVER_PORT", "0");
    std::env::set_var("ENABLE_AUTH", enable_auth.to_string());
    std::env::set_var("SUPABASE_JWT_SECRET", "test_secret");
    std::env::set_var("DEFAULT_USER_ID", "test_user");
    // Respect a caller-provided REDIS_URL; fall back to a local instance.
    if std::env::var("REDIS_URL").is_err() {
        std::env::set_var("REDIS_URL", "redis://localhost:6379");
    }

    let mut config = Config::load(None);
    config.port = 0; // Force OS to assign port

    let redis_client =
        redis::Client::open(config.redis_url.clone()).expect("Failed to connect to Redis");
    let stats = Arc::new(Stats::default());
    let mgr = Arc::new(ConnectionManager::new(
        redis_client,
        config.execution_event_bus_name.clone(),
        stats.clone(),
    ));
    // Keep a handle for the broadcaster before `mgr` moves into the app state.
    let broadcaster_mgr = mgr.clone();

    let state = AppState {
        mgr,
        config: Arc::new(config),
        stats,
    };
    let app = Router::new()
        .route("/ws", get(ws_handler))
        .layer(axum::Extension(state));

    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    let server_url = format!("ws://{addr}");

    let server_handle = tokio::spawn(async move {
        let broadcaster = async move {
            broadcaster_mgr.run_broadcaster().await;
        };
        let server = async move {
            axum::serve(listener, app.into_make_service())
                .await
                .unwrap();
        };
        // Both futures live inside this task: aborting it drops them together.
        tokio::join!(broadcaster, server);
    });

    // Give server time to start
    // NOTE(review): this fixed delay is part of every measured iteration.
    tokio::time::sleep(Duration::from_millis(100)).await;

    (server_url, server_handle)
}
/// Builds an HS256-signed JWT for `user_id`.
///
/// The token carries `sub = user_id`, `aud = ["authenticated"]` and an `exp`
/// one hour in the future. It is signed with the same `test_secret` that
/// `create_test_server` exports as `SUPABASE_JWT_SECRET`.
///
/// Panics if encoding fails.
fn create_jwt_token(user_id: &str) -> String {
    use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
    use serde::Serialize;

    #[derive(Serialize)]
    struct Claims {
        sub: String,
        aud: Vec<String>,
        exp: usize,
    }

    let expires_at = chrono::Utc::now() + chrono::Duration::hours(1);
    let header = Header::new(Algorithm::HS256);
    let signing_key = EncodingKey::from_secret(b"test_secret");

    let payload = Claims {
        sub: user_id.to_owned(),
        aud: vec![String::from("authenticated")],
        exp: expires_at.timestamp() as usize,
    };

    encode(&header, &payload, &signing_key).unwrap()
}
/// Benchmarks opening a WebSocket connection in three scenarios: auth
/// disabled, a valid JWT, and an invalid token that the server must reject.
///
/// Every iteration starts its own server through `create_test_server`, so the
/// measured time covers server start-up as well as the handshake.
fn benchmark_connection_establishment(c: &mut Criterion) {
    let runtime = Runtime::new().unwrap();
    let mut group = c.benchmark_group("connection_establishment");
    group.measurement_time(Duration::from_secs(30));

    // Auth disabled: no token in the URL.
    group.bench_function("no_auth", |bencher| {
        bencher.to_async(&runtime).iter_with_large_drop(|| async {
            let (base_url, server) = create_test_server(false).await;
            let endpoint = format!("{base_url}/ws");
            let (socket, _) = connect_async(&endpoint).await.unwrap();
            drop(socket);
            server.abort();
        });
    });

    // Auth enabled with a correctly signed token.
    group.bench_function("valid_auth", |bencher| {
        bencher.to_async(&runtime).iter_with_large_drop(|| async {
            let (base_url, server) = create_test_server(true).await;
            let jwt = create_jwt_token("test_user");
            let endpoint = format!("{base_url}/ws?token={jwt}");
            let (socket, _) = connect_async(&endpoint).await.unwrap();
            drop(socket);
            server.abort();
        });
    });

    // Auth enabled with a garbage token: either the handshake fails, or the
    // first frame received must be a close frame.
    group.bench_function("invalid_auth", |bencher| {
        bencher.to_async(&runtime).iter_with_large_drop(|| async {
            let (base_url, server) = create_test_server(true).await;
            let endpoint = format!("{base_url}/ws?token=invalid");
            let rejected = match connect_async(&endpoint).await {
                Err(_) => true,
                Ok((mut socket, _)) => {
                    // Should receive close frame
                    matches!(socket.next().await, Some(Ok(Message::Close(_))))
                }
            };
            assert!(rejected);
            server.abort();
        });
    });

    group.finish();
}
/// Benchmarks the subscribe and unsubscribe round trips against a server
/// running with auth disabled.
///
/// A reply is only checked for `success == true` when the next frame is a
/// text message; any other frame (or none) is ignored.
fn benchmark_subscriptions(c: &mut Criterion) {
    let runtime = Runtime::new().unwrap();
    let mut group = c.benchmark_group("subscriptions");
    group.measurement_time(Duration::from_secs(20));

    group.bench_function("subscribe_graph_execution", |bencher| {
        bencher.to_async(&runtime).iter_with_large_drop(|| async {
            let (base_url, server) = create_test_server(false).await;
            let endpoint = format!("{base_url}/ws");
            let (mut socket, _) = connect_async(&endpoint).await.unwrap();

            let subscribe = json!({
                "method": "subscribe_graph_execution",
                "data": {
                    "graph_exec_id": "test_exec_123"
                }
            });
            let frame = Message::Text(subscribe.to_string());
            socket.send(frame).await.unwrap();

            // Wait for the acknowledgement.
            if let Some(Ok(Message::Text(reply))) = socket.next().await {
                let parsed = serde_json::from_str::<serde_json::Value>(&reply).unwrap();
                assert_eq!(parsed["success"], true);
            }

            server.abort();
        });
    });

    group.bench_function("unsubscribe", |bencher| {
        bencher.to_async(&runtime).iter_with_large_drop(|| async {
            let (base_url, server) = create_test_server(false).await;
            let endpoint = format!("{base_url}/ws");
            let (mut socket, _) = connect_async(&endpoint).await.unwrap();

            // Subscribe first so there is a channel to leave.
            let subscribe = json!({
                "method": "subscribe_graph_execution",
                "data": {
                    "graph_exec_id": "test_exec_123"
                }
            });
            let frame = Message::Text(subscribe.to_string());
            socket.send(frame).await.unwrap();
            socket.next().await; // Consume the subscribe response

            // Channel key uses the default user set by `create_test_server`.
            let unsubscribe = json!({
                "method": "unsubscribe",
                "data": {
                    "channel": "test_user|graph_exec#test_exec_123"
                }
            });
            let frame = Message::Text(unsubscribe.to_string());
            socket.send(frame).await.unwrap();

            // Wait for the acknowledgement.
            if let Some(Ok(Message::Text(reply))) = socket.next().await {
                let parsed = serde_json::from_str::<serde_json::Value>(&reply).unwrap();
                assert_eq!(parsed["success"], true);
            }

            server.abort();
        });
    });

    group.finish();
}
/// Benchmarks heartbeat throughput: sends 10, 100 or 1000 heartbeat messages
/// over one connection, then reads the same number of frames back.
///
/// Throughput is reported in elements, one element per message.
fn benchmark_message_throughput(c: &mut Criterion) {
    let runtime = Runtime::new().unwrap();
    let mut group = c.benchmark_group("message_throughput");
    group.measurement_time(Duration::from_secs(30));

    for total in [10, 100, 1000] {
        group.throughput(Throughput::Elements(total as u64));
        group.bench_with_input(
            BenchmarkId::from_parameter(total),
            &total,
            |bencher, &total| {
                bencher.to_async(&runtime).iter_with_large_drop(|| async {
                    let (base_url, server) = create_test_server(false).await;
                    let endpoint = format!("{base_url}/ws");
                    let (mut socket, _) = connect_async(&endpoint).await.unwrap();

                    // Send phase: queue every heartbeat before reading anything.
                    for _ in 0..total {
                        let ping = json!({
                            "method": "heartbeat",
                            "data": "ping"
                        });
                        let frame = Message::Text(ping.to_string());
                        socket.send(frame).await.unwrap();
                    }

                    // Receive phase: one frame per heartbeat sent.
                    for _ in 0..total {
                        socket.next().await;
                    }

                    server.abort();
                });
            },
        );
    }

    group.finish();
}
/// Benchmarks 100, 500 and 1000 simultaneous clients against one server.
///
/// Each client runs in its own task: it connects, subscribes to a channel
/// unique to that client, waits for a frame, sends one heartbeat and waits
/// for another frame. The iteration ends once every client task has finished.
fn benchmark_concurrent_connections(c: &mut Criterion) {
    let runtime = Runtime::new().unwrap();
    let mut group = c.benchmark_group("concurrent_connections");
    group.measurement_time(Duration::from_secs(60));
    group.sample_size(10);

    for client_count in [100, 500, 1000] {
        group.throughput(Throughput::Elements(client_count as u64));
        group.bench_with_input(
            BenchmarkId::from_parameter(client_count),
            &client_count,
            |bencher, &client_count| {
                bencher.to_async(&runtime).iter_with_large_drop(|| async {
                    let (base_url, server) = create_test_server(false).await;
                    let endpoint = format!("{base_url}/ws");

                    // Launch every client before awaiting any of them.
                    let clients: Vec<_> = (0..client_count)
                        .map(|i| {
                            let endpoint = endpoint.clone();
                            tokio::spawn(async move {
                                let (mut socket, _) = connect_async(&endpoint).await.unwrap();

                                // Subscribe to a channel unique to this client.
                                let subscribe = json!({
                                    "method": "subscribe_graph_execution",
                                    "data": {
                                        "graph_exec_id": format!("exec_{}", i)
                                    }
                                });
                                let frame = Message::Text(subscribe.to_string());
                                socket.send(frame).await.unwrap();
                                socket.next().await; // Wait for response

                                // Follow up with a single heartbeat.
                                let ping = json!({
                                    "method": "heartbeat",
                                    "data": "ping"
                                });
                                let frame = Message::Text(ping.to_string());
                                socket.send(frame).await.unwrap();
                                socket.next().await; // Wait for response

                                // Hand the socket back so it stays open until joined.
                                socket
                            })
                        })
                        .collect();

                    // Wait for all clients; a failed client task is ignored.
                    for client in clients {
                        let _ = client.await;
                    }

                    server.abort();
                });
            },
        );
    }

    group.finish();
}
/// Benchmarks deserializing a client message into `models::WSMessage`.
///
/// The message is a `subscribe_graph_execution` request whose
/// `graph_exec_id` is a run of `x` characters of length 100, 1000 or 10000.
/// The JSON text is built once per size; only parsing is timed.
fn benchmark_message_parsing(c: &mut Criterion) {
    let mut group = c.benchmark_group("message_parsing");

    for size in [100, 1000, 10000] {
        group.throughput(Throughput::Bytes(size as u64));
        group.bench_with_input(
            BenchmarkId::new("parse_json", size),
            &size,
            |bencher, &size| {
                let payload = json!({
                    "method": "subscribe_graph_execution",
                    "data": {
                        "graph_exec_id": "x".repeat(size)
                    }
                })
                .to_string();

                bencher.iter(|| {
                    let _ = serde_json::from_str::<models::WSMessage>(&payload).unwrap();
                });
            },
        );
    }

    group.finish();
}
/// Benchmarks deserializing a Redis event envelope into
/// `models::RedisEventWrapper`.
///
/// The sample is a `graph_execution_update` event in `RUNNING` state; its
/// JSON text is built once and only the parse is timed.
fn benchmark_redis_event_processing(c: &mut Criterion) {
    let mut group = c.benchmark_group("redis_event_processing");

    group.bench_function("parse_execution_event", |bencher| {
        let raw_event = json!({
            "payload": {
                "event_type": "graph_execution_update",
                "id": "exec_123",
                "graph_id": "graph_456",
                "graph_version": 1,
                "user_id": "user_789",
                "status": "RUNNING",
                "started_at": "2024-01-01T00:00:00Z",
                "inputs": {"test": "data"},
                "outputs": {}
            }
        })
        .to_string();

        bencher.iter(|| {
            let _ = serde_json::from_str::<models::RedisEventWrapper>(&raw_event).unwrap();
        });
    });

    group.finish();
}
// Register every benchmark function of this file under the `benches` group,
// then let Criterion generate the `main` entry point that runs that group.
criterion_group!(
benches,
benchmark_connection_establishment,
benchmark_subscriptions,
benchmark_message_throughput,
benchmark_concurrent_connections,
benchmark_message_parsing,
benchmark_redis_event_processing
);
criterion_main!(benches);

View File

@@ -0,0 +1,10 @@
# Clippy configuration for robust error handling
# Set the maximum cognitive complexity allowed
cognitive-complexity-threshold = 30
# Do not exempt test code from the `dbg_macro` lint (`dbg!` is flagged in tests too)
allow-dbg-in-tests = false
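# Restrict `missing_docs_in_private_items` to items visible within the crate (this option does not enable the lint by itself)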
missing-docs-in-crate-items = true

View File

@@ -0,0 +1,23 @@
# WebSocket API Configuration
# Server settings
host = "0.0.0.0"
port = 8001
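# Authentication (jwt_secret below is only a placeholder; the SUPABASE_JWT_SECRET environment variable overrides it)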
enable_auth = true
jwt_secret = "your-super-secret-jwt-token-with-at-least-32-characters-long"
jwt_algorithm = "HS256"
default_user_id = "default"
# Redis configuration
redis_url = "redis://:password@localhost:6379/"
# Event bus
execution_event_bus_name = "execution_event"
# Message size limit (in bytes)
max_message_size_limit = 512000
# CORS allowed origins
backend_cors_allow_origins = ["http://localhost:3000", "https://559f69c159ef.ngrok.app"]

View File

@@ -0,0 +1,75 @@
use futures_util::{SinkExt, StreamExt};
use serde_json::json;
use tokio_tungstenite::{connect_async, tungstenite::Message};
/// Example client: connects to a local server, then performs a subscribe,
/// a heartbeat and an unsubscribe exchange, printing each text reply.
///
/// Any connection, send or receive error aborts the run and is returned.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let url = "ws://localhost:8001/ws";
    println!("Connecting to {url}");
    let (mut ws_stream, _) = connect_async(url).await?;
    println!("Connected!");

    // Each entry is (progress line, request); they are played strictly in order.
    let exchanges = [
        (
            "Sending subscription request...",
            json!({
                "method": "subscribe_graph_execution",
                "data": {
                    "graph_exec_id": "test_exec_123"
                }
            }),
        ),
        (
            "Sending heartbeat...",
            json!({
                "method": "heartbeat",
                "data": "ping"
            }),
        ),
        (
            "Sending unsubscribe request...",
            json!({
                "method": "unsubscribe",
                "data": {
                    "channel": "default|graph_exec#test_exec_123"
                }
            }),
        ),
    ];

    for (announcement, request) in exchanges {
        println!("{announcement}");
        ws_stream.send(Message::Text(request.to_string())).await?;

        // Read exactly one frame: print it if it is text, ignore other kinds,
        // and propagate a transport error.
        if let Some(Message::Text(text)) = ws_stream.next().await.transpose()? {
            println!("Received: {text}");
        }
    }

    println!("Closing connection...");
    ws_stream.close(None).await?;
    Ok(())
}

View File

@@ -0,0 +1,99 @@
use jsonwebtoken::Algorithm;
use serde::Deserialize;
use std::env;
use std::fs;
use std::path::Path;
use std::str::FromStr;
/// Runtime settings of the WebSocket server.
///
/// Deserializable from a TOML file; `Config::load` also applies overrides
/// taken from environment variables.
#[derive(Clone, Debug, Deserialize)]
pub struct Config {
/// Host part of the listen address (env `WEBSOCKET_SERVER_HOST`).
pub host: String,
/// Port part of the listen address (env `WEBSOCKET_SERVER_PORT`).
pub port: u16,
/// Whether connections must present a JWT (env `ENABLE_AUTH`).
pub enable_auth: bool,
/// Secret used to verify JWTs (env `SUPABASE_JWT_SECRET`).
pub jwt_secret: String,
/// Algorithm expected on incoming JWTs (env `JWT_ALGORITHM`).
pub jwt_algorithm: Algorithm,
/// First segment of the Redis channels that carry execution events
/// (env `EXECUTION_EVENT_BUS_NAME`).
pub execution_event_bus_name: String,
/// Connection URL of the Redis server (env `REDIS_URL`).
pub redis_url: String,
/// User id assigned to connections when authentication is disabled
/// (env `DEFAULT_USER_ID`).
pub default_user_id: String,
/// Upper bound for an incoming text message (env `MAX_MESSAGE_SIZE_LIMIT`).
pub max_message_size_limit: usize,
/// Origins allowed by CORS; in the environment this is a comma-separated
/// list (env `BACKEND_CORS_ALLOW_ORIGINS`).
pub backend_cors_allow_origins: Vec<String>,
}
impl Config {
pub fn load(config_path: Option<&Path>) -> Self {
let path = config_path.unwrap_or(Path::new("config.toml"));
let toml_result = fs::read_to_string(path)
.ok()
.and_then(|s| toml::from_str::<Config>(&s).ok());
let mut config = match toml_result {
Some(config) => config,
None => Config {
host: env::var("WEBSOCKET_SERVER_HOST").unwrap_or_else(|_| "0.0.0.0".to_string()),
port: env::var("WEBSOCKET_SERVER_PORT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(8001),
enable_auth: env::var("ENABLE_AUTH")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(true),
jwt_secret: env::var("SUPABASE_JWT_SECRET")
.unwrap_or_else(|_| "dummy_secret_for_no_auth".to_string()),
jwt_algorithm: Algorithm::HS256,
execution_event_bus_name: env::var("EXECUTION_EVENT_BUS_NAME")
.unwrap_or_else(|_| "execution_event".to_string()),
redis_url: env::var("REDIS_URL")
.unwrap_or_else(|_| "redis://localhost/".to_string()),
default_user_id: "default".to_string(),
max_message_size_limit: env::var("MAX_MESSAGE_SIZE_LIMIT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(512000),
backend_cors_allow_origins: env::var("BACKEND_CORS_ALLOW_ORIGINS")
.unwrap_or_default()
.split(',')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect(),
},
};
if let Ok(v) = env::var("WEBSOCKET_SERVER_HOST") {
config.host = v;
}
if let Ok(v) = env::var("WEBSOCKET_SERVER_PORT") {
config.port = v.parse().unwrap_or(8001);
}
if let Ok(v) = env::var("ENABLE_AUTH") {
config.enable_auth = v.parse().unwrap_or(true);
}
if let Ok(v) = env::var("SUPABASE_JWT_SECRET") {
config.jwt_secret = v;
}
if let Ok(v) = env::var("JWT_ALGORITHM") {
config.jwt_algorithm = Algorithm::from_str(&v).unwrap_or(Algorithm::HS256);
}
if let Ok(v) = env::var("EXECUTION_EVENT_BUS_NAME") {
config.execution_event_bus_name = v;
}
if let Ok(v) = env::var("REDIS_URL") {
config.redis_url = v;
}
if let Ok(v) = env::var("DEFAULT_USER_ID") {
config.default_user_id = v;
}
if let Ok(v) = env::var("MAX_MESSAGE_SIZE_LIMIT") {
config.max_message_size_limit = v.parse().unwrap_or(512000);
}
if let Ok(v) = env::var("BACKEND_CORS_ALLOW_ORIGINS") {
config.backend_cors_allow_origins = v
.split(',')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect();
}
config
}
}

View File

@@ -0,0 +1,277 @@
use futures::StreamExt;
use redis::Client as RedisClient;
use std::collections::{HashMap, HashSet};
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use tracing::{debug, error, info, warn};
use crate::models::{ExecutionEvent, RedisEventWrapper, WSMessage};
use crate::stats::Stats;
/// Shared state that maps WebSocket clients to the channels they follow and
/// relays Redis events to them.
pub struct ConnectionManager {
/// Channel key -> ids of the clients subscribed to it.
pub subscribers: RwLock<HashMap<String, HashSet<u64>>>,
/// Client id -> (user id, sender feeding that client's outgoing queue).
pub clients: RwLock<HashMap<u64, (String, mpsc::Sender<String>)>>,
/// Client id -> channel keys the client is subscribed to.
pub client_channels: RwLock<HashMap<u64, HashSet<String>>>,
/// Counter from which client ids are drawn.
pub next_id: AtomicU64,
/// Redis client used to open the pub/sub connection.
pub redis_client: RedisClient,
/// Expected first segment of Redis channel names; other channels are ignored.
pub bus_name: String,
/// Counters updated while events are received and forwarded.
pub stats: Arc<Stats>,
}
impl ConnectionManager {
/// Creates a manager with no clients or subscriptions; client ids start at 0.
pub fn new(redis_client: RedisClient, bus_name: String, stats: Arc<Stats>) -> Self {
    Self {
        redis_client,
        bus_name,
        stats,
        next_id: AtomicU64::default(),
        subscribers: RwLock::default(),
        clients: RwLock::default(),
        client_channels: RwLock::default(),
    }
}
/// Runs the Redis event broadcaster forever.
///
/// Whenever the inner loop returns, whether cleanly or with an error, the
/// outcome is logged and the loop is started again after five seconds.
pub async fn run_broadcaster(self: Arc<Self>) {
    info!("🚀 Starting Redis event broadcaster");
    let restart_delay = tokio::time::Duration::from_secs(5);

    loop {
        if let Err(e) = self.run_broadcaster_inner().await {
            error!("❌ Event broadcaster error: {}, restarting in 5 seconds", e);
        } else {
            warn!("⚠️ Event broadcaster stopped unexpectedly, restarting in 5 seconds");
        }
        tokio::time::sleep(restart_delay).await;
    }
}
/// Listens on Redis pub/sub and forwards execution events to subscribed
/// WebSocket clients until the pub/sub stream ends.
///
/// Every Redis channel is received (`psubscribe("*")`); only channels shaped
/// `{bus_name}/{user_id}/{graph_id}/{graph_exec_id}` are processed. The
/// payload is parsed as a `RedisEventWrapper` and re-sent as a `WSMessage` on
/// `{user_id}|graph_exec#{graph_exec_id}` and, for graph execution events,
/// also on `{user_id}|graph#{graph_id}|executions`.
///
/// Returns `Err` if the pub/sub connection or subscription fails, or when the
/// stream yields `None`. As written, no path returns `Ok`.
async fn run_broadcaster_inner(
self: &Arc<Self>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let mut pubsub = self.redis_client.get_async_pubsub().await?;
pubsub.psubscribe("*").await?;
info!(
"📡 Listening to all Redis events, filtering for bus: {}",
self.bus_name
);
let mut pubsub_stream = pubsub.on_message();
loop {
let msg = pubsub_stream.next().await;
match msg {
Some(msg) => {
let channel: String = msg.get_channel_name().to_string();
debug!("📨 Received message on Redis channel: {}", channel);
// Counted before any filtering, so this includes messages ignored below.
self.stats
.redis_messages_received
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let payload: String = match msg.get_payload() {
Ok(p) => p,
Err(e) => {
warn!("⚠️ Failed to get payload from Redis message: {}", e);
self.stats
.errors_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
continue;
}
};
// Parse the channel format: execution_event/{user_id}/{graph_id}/{graph_exec_id}
let parts: Vec<&str> = channel.split('/').collect();
// Check if this is an execution event channel
if parts.len() != 4 || parts[0] != self.bus_name {
debug!(
"🚫 Ignoring non-execution event channel: {} (parts: {:?}, bus_name: {})",
channel, parts, self.bus_name
);
self.stats
.redis_messages_ignored
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
continue;
}
let user_id = parts[1];
let graph_id = parts[2];
let graph_exec_id = parts[3];
debug!(
"📥 Received event - user: {}, graph: {}, exec: {}",
user_id, graph_id, graph_exec_id
);
// Parse the wrapped event
let wrapped_event = match RedisEventWrapper::parse(&payload) {
Ok(e) => e,
Err(e) => {
warn!("⚠️ Failed to parse event JSON: {}, payload: {}", e, payload);
self.stats
.errors_json_parse
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
self.stats
.errors_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
continue;
}
};
let event = wrapped_event.payload;
debug!("📦 Event received: {:?}", event);
// Choose the outgoing method name and re-serialize the inner event
// (without the `event_type` tag) as the message data.
let (method, event_json) = match &event {
ExecutionEvent::GraphExecutionUpdate(graph_event) => {
self.stats
.graph_execution_events
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
self.stats
.events_received_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
(
"graph_execution_event",
match serde_json::to_value(graph_event) {
Ok(v) => v,
Err(e) => {
error!("❌ Failed to serialize graph event: {}", e);
continue;
}
},
)
}
ExecutionEvent::NodeExecutionUpdate(node_event) => {
self.stats
.node_execution_events
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
self.stats
.events_received_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
(
"node_execution_event",
match serde_json::to_value(node_event) {
Ok(v) => v,
Err(e) => {
error!("❌ Failed to serialize node event: {}", e);
continue;
}
},
)
}
};
// Create the channel keys in the format expected by WebSocket clients
let mut channels_to_notify = Vec::new();
// For both event types, notify the specific execution channel
let exec_channel = format!("{user_id}|graph_exec#{graph_exec_id}");
channels_to_notify.push(exec_channel.clone());
// For graph execution events, also notify the graph executions channel
if matches!(&event, ExecutionEvent::GraphExecutionUpdate(_)) {
let graph_channel = format!("{user_id}|graph#{graph_id}|executions");
channels_to_notify.push(graph_channel);
}
debug!(
"📢 Broadcasting {} event to channels: {:?}",
method, channels_to_notify
);
// The subscribers read lock is held for the rest of this iteration.
let subs = self.subscribers.read().await;
// Log current subscriber state
debug!("📊 Current subscribers count: {}", subs.len());
for channel_key in channels_to_notify {
let ws_msg = WSMessage {
method: method.to_string(),
channel: Some(channel_key.clone()),
data: Some(event_json.clone()),
..Default::default()
};
// The message is serialized once per channel and cloned per client below.
let json_msg = match serde_json::to_string(&ws_msg) {
Ok(j) => {
debug!("📤 Sending WebSocket message: {}", j);
j
}
Err(e) => {
error!("❌ Failed to serialize WebSocket message: {}", e);
continue;
}
};
if let Some(client_ids) = subs.get(&channel_key) {
let clients = self.clients.read().await;
let client_count = client_ids.len();
debug!(
"📣 Broadcasting to {} clients on channel: {}",
client_count, channel_key
);
for &cid in client_ids {
// `user_id` here is the client's own user id and shadows the one
// parsed from the Redis channel name.
if let Some((user_id, tx)) = clients.get(&cid) {
// Non-blocking send first so one slow client cannot stall the loop.
match tx.try_send(json_msg.clone()) {
Ok(_) => {
debug!(
"✅ Message sent immediately to client {} (user: {})",
cid, user_id
);
self.stats
.messages_sent_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
Err(mpsc::error::TrySendError::Full(_)) => {
// Channel is full, try with a small timeout
// The retry runs in its own task; it gives up after 100 ms and
// then counts the message as failed.
let tx_clone = tx.clone();
let msg_clone = json_msg.clone();
let stats_clone = self.stats.clone();
tokio::spawn(async move {
match tokio::time::timeout(
std::time::Duration::from_millis(100),
tx_clone.send(msg_clone),
)
.await {
Ok(Ok(_)) => {
stats_clone
.messages_sent_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
_ => {
stats_clone
.messages_failed_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}
});
warn!("⚠️ Channel full for client {} (user: {}), sending async", cid, user_id);
}
Err(mpsc::error::TrySendError::Closed(_)) => {
warn!(
"⚠️ Channel closed for client {} (user: {})",
cid, user_id
);
self.stats
.messages_failed_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}
} else {
warn!("⚠️ Client {} not found in clients map", cid);
}
}
} else {
info!("📭 No subscribers for channel: {}", channel_key);
}
}
}
None => {
// The caller logs the error and restarts the broadcaster.
return Err("❌ Redis pubsub stream ended".into());
}
}
}
}
}

View File

@@ -0,0 +1,442 @@
use axum::extract::ws::{CloseFrame, Message, WebSocket};
use axum::{
extract::{Query, WebSocketUpgrade},
http::HeaderMap,
response::IntoResponse,
Extension,
};
use jsonwebtoken::{decode, DecodingKey, Validation};
use serde_json::{json, Value};
use std::collections::HashMap;
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};
use crate::connection_manager::ConnectionManager;
use crate::models::{Claims, WSMessage};
use crate::AppState;
// Helper function to safely serialize messages
fn serialize_message(msg: &WSMessage) -> String {
serde_json::to_string(msg).unwrap_or_else(|e| {
error!("❌ Failed to serialize WebSocket message: {}", e);
json!({"method": "error", "success": false, "error": "Internal serialization error"})
.to_string()
})
}
pub async fn ws_handler(
ws: WebSocketUpgrade,
query: Query<HashMap<String, String>>,
_headers: HeaderMap,
Extension(state): Extension<AppState>,
) -> impl IntoResponse {
let token = query.0.get("token").cloned();
let mut user_id = state.config.default_user_id.clone();
let mut auth_error_code: Option<u16> = None;
if state.config.enable_auth {
match token {
Some(token_str) => {
debug!("🔐 Authenticating WebSocket connection");
let mut validation = Validation::new(state.config.jwt_algorithm);
validation.set_audience(&["authenticated"]);
let key = DecodingKey::from_secret(state.config.jwt_secret.as_bytes());
match decode::<Claims>(&token_str, &key, &validation) {
Ok(token_data) => {
user_id = token_data.claims.sub.clone();
debug!("✅ WebSocket authenticated for user: {}", user_id);
}
Err(e) => {
warn!("⚠️ JWT validation failed: {}", e);
auth_error_code = Some(4003);
}
}
}
None => {
warn!("⚠️ Missing authentication token in WebSocket connection");
auth_error_code = Some(4001);
}
}
} else {
debug!("🔓 WebSocket connection without auth (auth disabled)");
}
if let Some(code) = auth_error_code {
error!("❌ WebSocket authentication failed with code: {}", code);
state
.mgr
.stats
.connections_failed_auth
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
state
.mgr
.stats
.connections_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
return ws
.on_upgrade(move |mut socket: WebSocket| async move {
let close_frame = Some(CloseFrame {
code,
reason: "Authentication failed".into(),
});
let _ = socket.send(Message::Close(close_frame)).await;
let _ = socket.close().await;
})
.into_response();
}
debug!("✅ WebSocket connection established for user: {}", user_id);
ws.on_upgrade(move |socket| {
handle_socket(
socket,
user_id,
state.mgr.clone(),
state.config.max_message_size_limit,
)
})
}
/// Records one subscription (`add == true`) or one unsubscription
/// (`add == false`) for `channel` in the manager's statistics.
///
/// Adding bumps the lifetime and active subscription counters and the
/// per-channel subscriber count. Removing bumps the unsubscription counter,
/// lowers the active counter and the per-channel count, and drops the
/// per-channel entry once it reaches zero.
///
/// The active counter is lowered with saturation: a removal that has no
/// matching addition leaves it at 0 instead of wrapping an unsigned counter
/// around to a huge value.
async fn update_subscription_stats(mgr: &ConnectionManager, channel: &str, add: bool) {
    use std::sync::atomic::Ordering::Relaxed;

    let stats = &mgr.stats;
    if add {
        stats.subscriptions_total.fetch_add(1, Relaxed);
        stats.subscriptions_active.fetch_add(1, Relaxed);
        let mut channel_stats = stats.channels_active.write().await;
        *channel_stats.entry(channel.to_string()).or_insert(0) += 1;
    } else {
        stats.unsubscriptions_total.fetch_add(1, Relaxed);
        // The closure always returns `Some`, so the update cannot fail.
        let _ = stats
            .subscriptions_active
            .fetch_update(Relaxed, Relaxed, |active| Some(active.saturating_sub(1)));
        let mut channel_stats = stats.channels_active.write().await;
        if let Some(count) = channel_stats.get_mut(channel) {
            *count = count.saturating_sub(1);
            if *count == 0 {
                channel_stats.remove(channel);
            }
        }
    }
}
pub async fn handle_socket(
mut socket: WebSocket,
user_id: String,
mgr: std::sync::Arc<ConnectionManager>,
max_size: usize,
) {
let client_id = mgr
.next_id
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let (tx, mut rx) = mpsc::channel::<String>(10);
info!("👋 New WebSocket client {} for user: {}", client_id, user_id);
// Update connection stats
mgr.stats
.connections_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
mgr.stats
.connections_active
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
// Update active users
{
let mut active_users = mgr.stats.active_users.write().await;
let count = active_users.entry(user_id.clone()).or_insert(0);
*count += 1;
}
{
let mut clients = mgr.clients.write().await;
clients.insert(client_id, (user_id.clone(), tx));
}
{
let mut client_channels = mgr.client_channels.write().await;
client_channels.insert(client_id, std::collections::HashSet::new());
}
loop {
tokio::select! {
msg = rx.recv() => {
if let Some(msg) = msg {
if socket.send(Message::Text(msg)).await.is_err() {
break;
}
} else {
break;
}
}
incoming = socket.recv() => {
let msg = match incoming {
Some(Ok(msg)) => msg,
_ => break,
};
match msg {
Message::Text(text) => {
if text.len() > max_size {
warn!("⚠️ Message from client {} exceeds size limit: {} > {}", client_id, text.len(), max_size);
let err_resp = serialize_message(&WSMessage {
method: "error".to_string(),
success: Some(false),
error: Some("Message exceeds size limit".to_string()),
..Default::default()
});
if socket.send(Message::Text(err_resp)).await.is_err() {
break;
}
continue;
}
mgr.stats.messages_received_total.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let ws_msg: WSMessage = match serde_json::from_str(&text) {
Ok(m) => m,
Err(e) => {
warn!("⚠️ Invalid message format from client {}: {}", client_id, e);
mgr.stats.errors_json_parse.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
mgr.stats.errors_total.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let err_resp = serialize_message(&WSMessage {
method: "error".to_string(),
success: Some(false),
error: Some("Invalid message format. Review the schema and retry".to_string()),
..Default::default()
});
if socket.send(Message::Text(err_resp)).await.is_err() {
break;
}
continue;
}
};
debug!("📥 Received {} message from client {}", ws_msg.method, client_id);
match ws_msg.method.as_str() {
"subscribe_graph_execution" => {
let graph_exec_id = match &ws_msg.data {
Some(Value::Object(map)) => map.get("graph_exec_id").and_then(|v| v.as_str()),
_ => None,
};
let Some(graph_exec_id) = graph_exec_id else {
warn!("⚠️ Missing graph_exec_id in subscribe_graph_execution from client {}", client_id);
let err_resp = json!({"method": "error", "success": false, "error": "Missing graph_exec_id"});
if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
break;
}
continue;
};
let channel = format!("{user_id}|graph_exec#{graph_exec_id}");
debug!("📌 Client {} subscribing to channel: {}", client_id, channel);
{
let mut subs = mgr.subscribers.write().await;
subs.entry(channel.clone()).or_insert(std::collections::HashSet::new()).insert(client_id);
}
{
let mut chs = mgr.client_channels.write().await;
if let Some(set) = chs.get_mut(&client_id) {
set.insert(channel.clone());
}
}
// Update subscription stats
update_subscription_stats(&mgr, &channel, true).await;
let resp = WSMessage {
method: "subscribe_graph_execution".to_string(),
success: Some(true),
channel: Some(channel),
..Default::default()
};
if socket.send(Message::Text(serialize_message(&resp))).await.is_err() {
break;
}
}
"subscribe_graph_executions" => {
let graph_id = match &ws_msg.data {
Some(Value::Object(map)) => map.get("graph_id").and_then(|v| v.as_str()),
_ => None,
};
let Some(graph_id) = graph_id else {
let err_resp = json!({"method": "error", "success": false, "error": "Missing graph_id"});
if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
break;
}
continue;
};
let channel = format!("{user_id}|graph#{graph_id}|executions");
{
let mut subs = mgr.subscribers.write().await;
subs.entry(channel.clone()).or_insert(std::collections::HashSet::new()).insert(client_id);
}
{
let mut chs = mgr.client_channels.write().await;
if let Some(set) = chs.get_mut(&client_id) {
set.insert(channel.clone());
}
}
debug!("📌 Client {} subscribing to channel: {}", client_id, channel);
// Update subscription stats
update_subscription_stats(&mgr, &channel, true).await;
let resp = WSMessage {
method: "subscribe_graph_executions".to_string(),
success: Some(true),
channel: Some(channel),
..Default::default()
};
if socket.send(Message::Text(serialize_message(&resp))).await.is_err() {
break;
}
}
"unsubscribe" => {
let channel = match &ws_msg.data {
Some(Value::String(s)) => Some(s.as_str()),
Some(Value::Object(map)) => map.get("channel").and_then(|v| v.as_str()),
_ => None,
};
let Some(channel) = channel else {
let err_resp = json!({"method": "error", "success": false, "error": "Missing channel"});
if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
break;
}
continue;
};
let channel = channel.to_string();
if !channel.starts_with(&format!("{user_id}|")) {
let err_resp = json!({"method": "error", "success": false, "error": "Unauthorized channel"});
if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
break;
}
continue;
}
{
let mut subs = mgr.subscribers.write().await;
if let Some(set) = subs.get_mut(&channel) {
set.remove(&client_id);
if set.is_empty() {
subs.remove(&channel);
}
}
}
{
let mut chs = mgr.client_channels.write().await;
if let Some(set) = chs.get_mut(&client_id) {
set.remove(&channel);
}
}
// Update subscription stats
update_subscription_stats(&mgr, &channel, false).await;
let resp = WSMessage {
method: "unsubscribe".to_string(),
success: Some(true),
channel: Some(channel),
..Default::default()
};
if socket.send(Message::Text(serialize_message(&resp))).await.is_err() {
break;
}
}
"heartbeat" => {
if ws_msg.data == Some(Value::String("ping".to_string())) {
let resp = WSMessage {
method: "heartbeat".to_string(),
data: Some(Value::String("pong".to_string())),
success: Some(true),
..Default::default()
};
if socket.send(Message::Text(serialize_message(&resp))).await.is_err() {
break;
}
} else {
let err_resp = json!({"method": "error", "success": false, "error": "Invalid heartbeat"});
if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
break;
}
}
}
_ => {
warn!("❓ Unknown method '{}' from client {}", ws_msg.method, client_id);
let err_resp = json!({"method": "error", "success": false, "error": "Unknown method"});
if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
break;
}
}
}
}
Message::Close(_) => break,
Message::Ping(_) => {
if socket.send(Message::Pong(vec![])).await.is_err() {
break;
}
}
Message::Pong(_) => {}
_ => {}
}
}
else => break,
}
}
// Cleanup
debug!("👋 WebSocket client {} disconnected, cleaning up", client_id);
// Update connection stats
mgr.stats
.connections_active
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
// Update active users
{
let mut active_users = mgr.stats.active_users.write().await;
if let Some(count) = active_users.get_mut(&user_id) {
*count = count.saturating_sub(1);
if *count == 0 {
active_users.remove(&user_id);
}
}
}
let channels = {
let mut client_channels = mgr.client_channels.write().await;
client_channels.remove(&client_id).unwrap_or_default()
};
{
let mut subs = mgr.subscribers.write().await;
for channel in &channels {
if let Some(set) = subs.get_mut(channel) {
set.remove(&client_id);
if set.is_empty() {
subs.remove(channel);
}
}
}
}
// Update subscription stats for all channels the client was subscribed to
for channel in &channels {
update_subscription_stats(&mgr, channel, false).await;
}
{
let mut clients = mgr.clients.write().await;
clients.remove(&client_id);
}
debug!("✨ Cleanup completed for client {}", client_id);
}

View File

@@ -0,0 +1,26 @@
#![deny(warnings)]
#![deny(clippy::unwrap_used)]
#![deny(clippy::panic)]
#![deny(clippy::unimplemented)]
#![deny(clippy::todo)]
pub mod config;
pub mod connection_manager;
pub mod handlers;
pub mod models;
pub mod stats;
pub use config::Config;
pub use connection_manager::ConnectionManager;
pub use handlers::ws_handler;
pub use stats::Stats;
use std::sync::Arc;
/// State shared with request handlers; cloning only clones the `Arc`s.
#[derive(Clone)]
pub struct AppState {
/// Connection and subscription registry.
pub mgr: Arc<ConnectionManager>,
/// Loaded server configuration.
pub config: Arc<Config>,
/// Server statistics. NOTE(review): presumably the same instance the manager holds; confirm at the construction site.
pub stats: Arc<Stats>,
}

View File

@@ -0,0 +1,172 @@
use axum::{
body::Body,
http::{header, StatusCode},
response::Response,
routing::get,
Router,
};
use clap::Parser;
use std::sync::Arc;
use tokio::net::TcpListener;
use tower_http::cors::{Any, CorsLayer};
use tracing::{debug, error, info};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use crate::config::Config;
use crate::connection_manager::ConnectionManager;
use crate::handlers::ws_handler;
/// Serves the current statistics snapshot as JSON. Always succeeds.
async fn stats_handler(
    axum::Extension(state): axum::Extension<AppState>,
) -> Result<axum::response::Json<stats::StatsSnapshot>, StatusCode> {
    Ok(axum::response::Json(state.stats.snapshot().await))
}
/// Serves the current statistics in Prometheus text format.
///
/// Responds with 200 and content type `text/plain; version=0.0.4`; a failure
/// to build the response is reported as 500.
async fn prometheus_handler(
    axum::Extension(state): axum::Extension<AppState>,
) -> Result<Response, StatusCode> {
    let metrics_text = {
        let snapshot = state.stats.snapshot().await;
        state.stats.to_prometheus_format(&snapshot)
    };

    let built = Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, "text/plain; version=0.0.4")
        .body(Body::from(metrics_text));

    match built {
        Ok(response) => Ok(response),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}
mod config;
mod connection_manager;
mod handlers;
mod models;
mod stats;
// Command-line arguments of the server binary.
// Plain `//` comments are used on purpose: clap turns `///` doc comments
// into help text, which would change the program's output.
#[derive(Parser, Debug)]
#[command(author, version, about)]
struct Cli {
/// Path to a TOML configuration file
#[arg(short = 'c', long = "config", value_name = "FILE")]
config: Option<std::path::PathBuf>,
}
/// State shared with every request handler through an `Extension` layer;
/// cloning only clones the `Arc`s.
#[derive(Clone)]
pub struct AppState {
// Connection and subscription registry; also runs the Redis broadcaster.
mgr: Arc<ConnectionManager>,
// Loaded server configuration.
config: Arc<Config>,
// Statistics exposed by the `/stats` and `/metrics` routes.
stats: Arc<stats::Stats>,
}
#[tokio::main]
async fn main() {
// Initialize tracing
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "websocket=info,tower_http=debug".into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
info!("🚀 Starting WebSocket API server");
let cli = Cli::parse();
let config = Arc::new(Config::load(cli.config.as_deref()));
info!(
"⚙️ Configuration loaded - host: {}, port: {}, auth: {}",
config.host, config.port, config.enable_auth
);
let redis_client = match redis::Client::open(config.redis_url.clone()) {
Ok(client) => {
debug!("✅ Redis client created successfully");
client
}
Err(e) => {
error!(
"❌ Failed to create Redis client: {}. Please check REDIS_URL environment variable",
e
);
std::process::exit(1);
}
};
let stats = Arc::new(stats::Stats::default());
let mgr = Arc::new(ConnectionManager::new(
redis_client,
config.execution_event_bus_name.clone(),
stats.clone(),
));
let mgr_clone = mgr.clone();
tokio::spawn(async move {
debug!("📡 Starting event broadcaster task");
mgr_clone.run_broadcaster().await;
});
let state = AppState {
mgr,
config: config.clone(),
stats,
};
let app = Router::new()
.route("/ws", get(ws_handler))
.route("/stats", get(stats_handler))
.route("/metrics", get(prometheus_handler))
.layer(axum::Extension(state));
let cors = if config.backend_cors_allow_origins.is_empty() {
// If no specific origins configured, allow any origin but without credentials
CorsLayer::new()
.allow_methods(Any)
.allow_headers(Any)
.allow_origin(Any)
} else {
// If specific origins configured, allow credentials
CorsLayer::new()
.allow_methods([
axum::http::Method::GET,
axum::http::Method::POST,
axum::http::Method::PUT,
axum::http::Method::DELETE,
axum::http::Method::OPTIONS,
])
.allow_headers(vec![
axum::http::header::CONTENT_TYPE,
axum::http::header::AUTHORIZATION,
])
.allow_credentials(true)
.allow_origin(
config
.backend_cors_allow_origins
.iter()
.filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
.collect::<Vec<_>>(),
)
};
let app = app.layer(cors);
let addr = format!("{}:{}", config.host, config.port);
let listener = match TcpListener::bind(&addr).await {
Ok(listener) => {
info!("🎧 WebSocket server listening on: {}", addr);
listener
}
Err(e) => {
error!(
"❌ Failed to bind to {}: {}. Please check if the port is already in use",
addr, e
);
std::process::exit(1);
}
};
info!("✨ WebSocket API server ready to accept connections");
if let Err(e) = axum::serve(listener, app.into_make_service()).await {
error!("💥 Server error: {}", e);
std::process::exit(1);
}
}

View File

@@ -0,0 +1,103 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Envelope of every JSON message exchanged over the WebSocket.
///
/// Only `method` is required; the optional fields default to `None` when
/// absent on input and are omitted from the output when `None`.
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct WSMessage {
/// Name of the request, reply or event (for example `heartbeat`).
pub method: String,
/// Method-specific payload.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub data: Option<Value>,
/// Outcome flag set on replies.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub success: Option<bool>,
/// Channel key the message refers to.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub channel: Option<String>,
/// Human-readable reason set on error replies.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
}
/// The JWT claims this service reads from a token.
#[derive(Deserialize)]
pub struct Claims {
/// Subject claim; used as the connection's user id.
pub sub: String,
}
// Event models moved from events.rs
/// An execution event, discriminated by the JSON field `event_type`
/// (`graph_execution_update` or `node_execution_update`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "event_type")]
pub enum ExecutionEvent {
/// Update about a whole graph execution.
#[serde(rename = "graph_execution_update")]
GraphExecutionUpdate(GraphExecutionEvent),
/// Update about a single node execution.
#[serde(rename = "node_execution_update")]
NodeExecutionUpdate(NodeExecutionEvent),
}
/// Payload of a `graph_execution_update` event.
///
/// Timestamps are kept as the strings received; they are not parsed here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphExecutionEvent {
pub id: String,
pub graph_id: String,
pub graph_version: u32,
pub user_id: String,
pub status: ExecutionStatus,
pub started_at: Option<String>,
pub ended_at: Option<String>,
pub preset_id: Option<String>,
pub stats: Option<ExecutionStats>,
// Keep these as JSON since they vary by graph
// (both are required fields: there is no serde default for them)
pub inputs: Value,
pub outputs: Value,
}
/// Payload of a `node_execution_update` event.
///
/// Timestamps are kept as the strings received; they are not parsed here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeExecutionEvent {
pub node_exec_id: String,
pub node_id: String,
pub graph_exec_id: String,
pub graph_id: String,
pub graph_version: u32,
pub user_id: String,
pub block_id: String,
pub status: ExecutionStatus,
pub add_time: String,
pub queue_time: Option<String>,
pub start_time: Option<String>,
pub end_time: Option<String>,
// Keep these as JSON since they vary by node type
// (both are required fields: there is no serde default for them)
pub input_data: Value,
pub output_data: Value,
}
/// Statistics attached to a graph execution (see `GraphExecutionEvent::stats`).
///
/// Every field except `error` is required when deserializing.
// NOTE(review): units of `cost`, `duration*` and `node_exec_time*` are not stated
// anywhere in this file — confirm with the producer before displaying them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionStats {
pub cost: f64,
pub duration: f64,
pub duration_cpu_only: f64,
pub error: Option<String>,
pub node_error_count: u32,
pub node_exec_count: u32,
pub node_exec_time: f64,
pub node_exec_time_cpu_only: f64,
}
/// Status of a graph or node execution.
///
/// Variants are (de)serialized in SCREAMING_SNAKE_CASE, e.g. `Queued` <-> `"QUEUED"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ExecutionStatus {
Queued,
Running,
Completed,
Failed,
Incomplete,
Terminated,
}
// Wrapper for the Redis event that includes the payload
/// Deserialize-only envelope whose `payload` field holds the `ExecutionEvent`.
///
/// Only `payload` is declared; no `deny_unknown_fields` is set, so any other
/// keys in the incoming JSON object are ignored by serde.
#[derive(Debug, Deserialize)]
pub struct RedisEventWrapper {
pub payload: ExecutionEvent,
}
impl RedisEventWrapper {
    /// Deserialize a wrapper from its JSON text.
    ///
    /// Returns the `serde_json::Error` unchanged when `json_str` is not valid
    /// JSON or does not match the wrapper's shape.
    pub fn parse(json_str: &str) -> Result<Self, serde_json::Error> {
        let wrapper: Self = serde_json::from_str(json_str)?;
        Ok(wrapper)
    }
}

View File

@@ -0,0 +1,238 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::RwLock;
/// Live metrics registry.
///
/// Scalar metrics are `AtomicU64` values; the per-channel and per-user
/// breakdowns are `HashMap`s behind an async `RwLock`. `Default` yields all
/// counters at zero and both maps empty. Use `snapshot()` to obtain a
/// plain-value copy.
#[derive(Default)]
pub struct Stats {
// Connection metrics
pub connections_total: AtomicU64,
pub connections_active: AtomicU64,
pub connections_failed_auth: AtomicU64,
// Message metrics
pub messages_received_total: AtomicU64,
pub messages_sent_total: AtomicU64,
pub messages_failed_total: AtomicU64,
// Subscription metrics
pub subscriptions_total: AtomicU64,
pub subscriptions_active: AtomicU64,
pub unsubscriptions_total: AtomicU64,
// Event metrics by type
pub events_received_total: AtomicU64,
pub graph_execution_events: AtomicU64,
pub node_execution_events: AtomicU64,
// Redis metrics
pub redis_messages_received: AtomicU64,
pub redis_messages_ignored: AtomicU64,
// Channel metrics
pub channels_active: RwLock<HashMap<String, usize>>, // channel -> subscriber count
// User metrics
pub active_users: RwLock<HashMap<String, usize>>, // user_id -> connection count
// Error metrics
pub errors_total: AtomicU64,
pub errors_json_parse: AtomicU64,
pub errors_message_size: AtomicU64,
}
/// Plain-value, serializable copy of `Stats` as produced by `Stats::snapshot`.
///
/// The atomic counters are copied one-to-one; the two maps are reduced to
/// aggregate numbers (see the comments on the channel and user fields).
#[derive(Serialize, Deserialize)]
pub struct StatsSnapshot {
// Connection metrics
pub connections_total: u64,
pub connections_active: u64,
pub connections_failed_auth: u64,
// Message metrics
pub messages_received_total: u64,
pub messages_sent_total: u64,
pub messages_failed_total: u64,
// Subscription metrics
pub subscriptions_total: u64,
pub subscriptions_active: u64,
pub unsubscriptions_total: u64,
// Event metrics
pub events_received_total: u64,
pub graph_execution_events: u64,
pub node_execution_events: u64,
// Redis metrics
pub redis_messages_received: u64,
pub redis_messages_ignored: u64,
// Channel metrics
// Number of keys in `Stats::channels_active`.
pub channels_active_count: usize,
// Sum of all values (subscriber counts) in `Stats::channels_active`.
pub total_subscribers: usize,
// User metrics
// Number of keys in `Stats::active_users`.
pub active_users_count: usize,
// Error metrics
pub errors_total: u64,
pub errors_json_parse: u64,
pub errors_message_size: u64,
}
impl Stats {
pub async fn snapshot(&self) -> StatsSnapshot {
// Take read locks for HashMap data - it's ok if this is slightly stale
let channels = self.channels_active.read().await;
let total_subscribers: usize = channels.values().sum();
let channels_active_count = channels.len();
drop(channels); // Release lock early
let users = self.active_users.read().await;
let active_users_count = users.len();
drop(users); // Release lock early
StatsSnapshot {
connections_total: self.connections_total.load(Ordering::Relaxed),
connections_active: self.connections_active.load(Ordering::Relaxed),
connections_failed_auth: self.connections_failed_auth.load(Ordering::Relaxed),
messages_received_total: self.messages_received_total.load(Ordering::Relaxed),
messages_sent_total: self.messages_sent_total.load(Ordering::Relaxed),
messages_failed_total: self.messages_failed_total.load(Ordering::Relaxed),
subscriptions_total: self.subscriptions_total.load(Ordering::Relaxed),
subscriptions_active: self.subscriptions_active.load(Ordering::Relaxed),
unsubscriptions_total: self.unsubscriptions_total.load(Ordering::Relaxed),
events_received_total: self.events_received_total.load(Ordering::Relaxed),
graph_execution_events: self.graph_execution_events.load(Ordering::Relaxed),
node_execution_events: self.node_execution_events.load(Ordering::Relaxed),
redis_messages_received: self.redis_messages_received.load(Ordering::Relaxed),
redis_messages_ignored: self.redis_messages_ignored.load(Ordering::Relaxed),
channels_active_count,
total_subscribers,
active_users_count,
errors_total: self.errors_total.load(Ordering::Relaxed),
errors_json_parse: self.errors_json_parse.load(Ordering::Relaxed),
errors_message_size: self.errors_message_size.load(Ordering::Relaxed),
}
}
pub fn to_prometheus_format(&self, snapshot: &StatsSnapshot) -> String {
let mut output = String::new();
// Connection metrics
output.push_str("# HELP ws_connections_total Total number of WebSocket connections\n");
output.push_str("# TYPE ws_connections_total counter\n");
output.push_str(&format!(
"ws_connections_total {}\n\n",
snapshot.connections_total
));
output.push_str(
"# HELP ws_connections_active Current number of active WebSocket connections\n",
);
output.push_str("# TYPE ws_connections_active gauge\n");
output.push_str(&format!(
"ws_connections_active {}\n\n",
snapshot.connections_active
));
output
.push_str("# HELP ws_connections_failed_auth Total number of failed authentications\n");
output.push_str("# TYPE ws_connections_failed_auth counter\n");
output.push_str(&format!(
"ws_connections_failed_auth {}\n\n",
snapshot.connections_failed_auth
));
// Message metrics
output.push_str(
"# HELP ws_messages_received_total Total number of messages received from clients\n",
);
output.push_str("# TYPE ws_messages_received_total counter\n");
output.push_str(&format!(
"ws_messages_received_total {}\n\n",
snapshot.messages_received_total
));
output.push_str("# HELP ws_messages_sent_total Total number of messages sent to clients\n");
output.push_str("# TYPE ws_messages_sent_total counter\n");
output.push_str(&format!(
"ws_messages_sent_total {}\n\n",
snapshot.messages_sent_total
));
// Subscription metrics
output.push_str("# HELP ws_subscriptions_active Current number of active subscriptions\n");
output.push_str("# TYPE ws_subscriptions_active gauge\n");
output.push_str(&format!(
"ws_subscriptions_active {}\n\n",
snapshot.subscriptions_active
));
// Event metrics
output.push_str(
"# HELP ws_events_received_total Total number of events received from Redis\n",
);
output.push_str("# TYPE ws_events_received_total counter\n");
output.push_str(&format!(
"ws_events_received_total {}\n\n",
snapshot.events_received_total
));
output.push_str(
"# HELP ws_graph_execution_events_total Total number of graph execution events\n",
);
output.push_str("# TYPE ws_graph_execution_events_total counter\n");
output.push_str(&format!(
"ws_graph_execution_events_total {}\n\n",
snapshot.graph_execution_events
));
output.push_str(
"# HELP ws_node_execution_events_total Total number of node execution events\n",
);
output.push_str("# TYPE ws_node_execution_events_total counter\n");
output.push_str(&format!(
"ws_node_execution_events_total {}\n\n",
snapshot.node_execution_events
));
// Channel metrics
output.push_str("# HELP ws_channels_active Number of active channels\n");
output.push_str("# TYPE ws_channels_active gauge\n");
output.push_str(&format!(
"ws_channels_active {}\n\n",
snapshot.channels_active_count
));
output.push_str(
"# HELP ws_total_subscribers Total number of subscribers across all channels\n",
);
output.push_str("# TYPE ws_total_subscribers gauge\n");
output.push_str(&format!(
"ws_total_subscribers {}\n\n",
snapshot.total_subscribers
));
// User metrics
output.push_str("# HELP ws_active_users Number of unique users with active connections\n");
output.push_str("# TYPE ws_active_users gauge\n");
output.push_str(&format!(
"ws_active_users {}\n\n",
snapshot.active_users_count
));
// Error metrics
output.push_str("# HELP ws_errors_total Total number of errors\n");
output.push_str("# TYPE ws_errors_total counter\n");
output.push_str(&format!("ws_errors_total {}\n", snapshot.errors_total));
output
}
}

View File

@@ -7,5 +7,9 @@ class Settings:
self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
self.JWT_ALGORITHM: str = "HS256"
@property
def is_configured(self) -> bool:
return bool(self.JWT_SECRET_KEY)
settings = Settings()

View File

@@ -0,0 +1,166 @@
import asyncio
import contextlib
import logging
from functools import wraps
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, cast
import ldclient
from fastapi import HTTPException
from ldclient import Context, LDClient
from ldclient.config import Config
from typing_extensions import ParamSpec
from .config import SETTINGS
logger = logging.getLogger(__name__)
P = ParamSpec("P")
T = TypeVar("T")
def get_client() -> LDClient:
    """Return the client object handed out by ``ldclient.get()``."""
    client = ldclient.get()
    return client
def initialize_launchdarkly() -> None:
    """Configure the global LaunchDarkly client from ``SETTINGS``.

    When no SDK key is configured a warning is logged and nothing else
    happens. Otherwise the key is installed via ``ldclient.set_config`` and
    the outcome (initialized or not) is logged.
    """
    sdk_key = SETTINGS.launch_darkly_sdk_key
    # Only log whether a key exists, never the key itself.
    logger.debug(
        f"Initializing LaunchDarkly with SDK key: {'present' if sdk_key else 'missing'}"
    )

    if sdk_key:
        ldclient.set_config(Config(sdk_key))
        if not ldclient.get().is_initialized():
            logger.error("LaunchDarkly client failed to initialize")
        else:
            logger.info("LaunchDarkly client initialized successfully")
    else:
        logger.warning("LaunchDarkly SDK key not configured")
def shutdown_launchdarkly() -> None:
    """Close the LaunchDarkly client if it reports itself as initialized."""
    if not ldclient.get().is_initialized():
        return

    ldclient.get().close()
    logger.info("LaunchDarkly client closed successfully")
def create_context(
    user_id: str, additional_attributes: Optional[Dict[str, Any]] = None
) -> Context:
    """Build a LaunchDarkly context of kind ``"user"`` keyed by ``user_id``.

    Args:
        user_id: Context key; converted with ``str()`` before use.
        additional_attributes: Optional attribute name/value pairs that are
            set on the context builder. ``None`` or an empty dict adds nothing.

    Returns:
        The built ``Context``.
    """
    context_builder = Context.builder(str(user_id)).kind("user")

    for attribute_name, attribute_value in (additional_attributes or {}).items():
        context_builder.set(attribute_name, attribute_value)

    return context_builder.build()
def feature_flag(
    flag_key: str,
    default: bool = False,
) -> Callable[
    [Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
]:
    """
    Decorator for feature flag protected endpoints.

    The decorated callable must be invoked with a truthy ``user_id`` keyword
    argument. Before every call the flag ``flag_key`` is evaluated for that
    user; when the LaunchDarkly client is not initialized, ``default`` is used
    instead. Works for both sync and ``async def`` callables.

    Args:
        flag_key: Key of the LaunchDarkly flag to evaluate.
        default: Value used when the client is not initialized, and passed to
            ``variation`` as its fallback.

    Raises (from the decorated callable):
        ValueError: ``user_id`` keyword argument is missing or falsy.
        HTTPException: 404 when the flag evaluates to a falsy value.
        Exception: anything raised while evaluating the flag, or by the
            wrapped callable itself, propagates unchanged.
    """

    def require_enabled(user_id: Any) -> None:
        """Evaluate the flag for ``user_id``; raise unless it is enabled."""
        try:
            if not user_id:
                raise ValueError("user_id is required")

            if not get_client().is_initialized():
                logger.warning(
                    f"LaunchDarkly not initialized, using default={default}"
                )
                is_enabled = default
            else:
                context = create_context(str(user_id))
                is_enabled = get_client().variation(flag_key, context, default)
        except Exception as e:
            # Only genuine evaluation problems are reported here. The 404
            # below and errors raised by the wrapped callable are not flag
            # evaluation errors and must not be logged as such.
            logger.error(f"Error evaluating feature flag {flag_key}: {e}")
            raise

        if not is_enabled:
            raise HTTPException(status_code=404, detail="Feature not available")

    def decorator(
        func: Callable[P, Union[T, Awaitable[T]]],
    ) -> Callable[P, Union[T, Awaitable[T]]]:
        @wraps(func)
        async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            require_enabled(kwargs.get("user_id"))
            result = func(*args, **kwargs)
            if asyncio.iscoroutine(result):
                return await result
            return cast(T, result)

        @wraps(func)
        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            require_enabled(kwargs.get("user_id"))
            return cast(T, func(*args, **kwargs))

        # Pick the wrapper that matches the decorated callable's calling style.
        return cast(
            Callable[P, Union[T, Awaitable[T]]],
            async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper,
        )

    return decorator
def percentage_rollout(
    flag_key: str,
    default: bool = False,
) -> Callable[
    [Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
]:
    """Decorator for percentage-based rollouts.

    Currently a plain alias: it returns exactly what
    ``feature_flag(flag_key, default)`` returns.
    """
    rollout_decorator = feature_flag(flag_key, default)
    return rollout_decorator
def beta_feature(
    flag_key: Optional[str] = None,
    unauthorized_response: Any = {"message": "Not available in beta"},
) -> Callable[
    [Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
]:
    """Decorator for beta features.

    The flag that is evaluated is ``"beta-<flag_key>"`` when ``flag_key`` is
    given and ``"beta"`` otherwise; its default is ``False``.

    Note:
        ``unauthorized_response`` is accepted but never read by this
        function; a disabled flag is handled by ``feature_flag``.
    """
    if flag_key:
        actual_key = f"beta-{flag_key}"
    else:
        actual_key = "beta"
    return feature_flag(actual_key, False)
@contextlib.contextmanager
def mock_flag_variation(flag_key: str, return_value: Any):
    """Temporarily force ``flag_key`` to evaluate to ``return_value``.

    While the context is active, the client's ``variation`` attribute is
    replaced by a function that returns ``return_value`` for ``flag_key`` and
    delegates every other key to the original ``variation``. The original is
    restored on exit, including when the body raises.
    """
    original_variation = get_client().variation

    def patched_variation(key, context, default):
        if key == flag_key:
            return return_value
        return original_variation(key, context, default)

    get_client().variation = patched_variation
    try:
        yield
    finally:
        get_client().variation = original_variation

View File

@@ -0,0 +1,45 @@
import pytest
from ldclient import LDClient
from autogpt_libs.feature_flag.client import feature_flag, mock_flag_variation
@pytest.fixture
def ld_client(mocker):
    """Provide a mocked ``LDClient`` that ``ldclient.get()`` returns.

    The mock reports itself as initialized so that the flag is actually
    evaluated through ``variation``.
    """
    fake_client = mocker.Mock(spec=LDClient)
    fake_client.is_initialized.return_value = True
    mocker.patch("ldclient.get", return_value=fake_client)
    return fake_client
@pytest.mark.asyncio
async def test_feature_flag_enabled(ld_client):
    """An enabled flag lets the decorated coroutine run and return its value."""
    ld_client.variation.return_value = True

    @feature_flag("test-flag")
    async def test_function(user_id: str):
        return "success"

    # The decorator wraps coroutine functions in an async wrapper, so the
    # call must be awaited; without it `result` would be a coroutine object.
    result = await test_function(user_id="test-user")

    assert result == "success"
    ld_client.variation.assert_called_once()
@pytest.mark.asyncio
async def test_feature_flag_unauthorized_response(ld_client):
    """A disabled flag makes the decorated coroutine fail with a 404."""
    ld_client.variation.return_value = False

    @feature_flag("test-flag")
    async def test_function(user_id: str):
        return "success"

    # `feature_flag` raises an HTTPException(status_code=404) for a disabled
    # flag; it never returns an error payload. The status code is checked via
    # the attribute so this test module needs no extra import.
    with pytest.raises(Exception) as exc_info:
        await test_function(user_id="test-user")

    assert getattr(exc_info.value, "status_code", None) == 404
def test_mock_flag_variation(ld_client):
    """``mock_flag_variation`` forces the flag to the requested value."""
    with mock_flag_variation("test-flag", True):
        assert ld_client.variation("test-flag", None, False)

    # Mocked to False, the flag must evaluate falsy.
    with mock_flag_variation("test-flag", False):
        assert not ld_client.variation("test-flag", None, False)

View File

@@ -0,0 +1,15 @@
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """Feature-flag settings model.

    Declares a single field, ``launch_darkly_sdk_key``, populated via the
    ``LAUNCH_DARKLY_SDK_KEY`` validation alias and defaulting to an empty
    string. Matching is case-sensitive and unknown inputs are ignored.
    """

    model_config = SettingsConfigDict(extra="ignore", case_sensitive=True)

    launch_darkly_sdk_key: str = Field(
        validation_alias="LAUNCH_DARKLY_SDK_KEY",
        description="The Launch Darkly SDK key",
        default="",
    )


# Module-level instance, created once at import time.
SETTINGS = Settings()

View File

@@ -1,8 +1,6 @@
"""Logging module for Auto-GPT."""
import logging
import os
import socket
import sys
from pathlib import Path
@@ -12,15 +10,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
from .filters import BelowLevelFilter
from .formatters import AGPTFormatter
# Configure global socket timeout and gRPC keepalive to prevent deadlocks
# This must be done at import time before any gRPC connections are established
socket.setdefaulttimeout(30) # 30-second socket timeout
# Enable gRPC keepalive to detect dead connections faster
os.environ.setdefault("GRPC_KEEPALIVE_TIME_MS", "30000") # 30 seconds
os.environ.setdefault("GRPC_KEEPALIVE_TIMEOUT_MS", "5000") # 5 seconds
os.environ.setdefault("GRPC_KEEPALIVE_PERMIT_WITHOUT_CALLS", "true")
LOG_DIR = Path(__file__).parent.parent.parent.parent / "logs"
LOG_FILE = "activity.log"
DEBUG_LOG_FILE = "debug.log"
@@ -90,6 +79,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
Note: This function is typically called at the start of the application
to set up the logging infrastructure.
"""
config = LoggingConfig()
log_handlers: list[logging.Handler] = []
@@ -115,17 +105,13 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
if config.enable_cloud_logging or force_cloud_logging:
import google.cloud.logging
from google.cloud.logging.handlers import CloudLoggingHandler
from google.cloud.logging_v2.handlers.transports import (
BackgroundThreadTransport,
)
from google.cloud.logging_v2.handlers.transports.sync import SyncTransport
client = google.cloud.logging.Client()
# Use BackgroundThreadTransport to prevent blocking the main thread
# and deadlocks when gRPC calls to Google Cloud Logging hang
cloud_handler = CloudLoggingHandler(
client,
name="autogpt_logs",
transport=BackgroundThreadTransport,
transport=SyncTransport,
)
cloud_handler.setLevel(config.level)
log_handlers.append(cloud_handler)

View File

@@ -1,5 +1,39 @@
import logging
import re
from typing import Any
import uvicorn.config
from colorama import Fore
def remove_color_codes(s: str) -> str:
    """Return ``s`` with every escape sequence matched by the pattern removed.

    The pattern matches an ESC character followed either by a single byte in
    ``@-Z`` / ``\\-_`` or by a ``[``-introduced sequence such as ``\\x1b[31m``.
    """
    escape_sequence = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
    return escape_sequence.sub("", s)
def fmt_kwargs(kwargs: dict) -> str:
    """Format a mapping as ``name=repr(value)`` pairs joined by ``", "``.

    Pairs appear in the mapping's iteration order; an empty mapping yields
    an empty string.
    """
    rendered_pairs = []
    for name, value in kwargs.items():
        rendered_pairs.append(f"{name}={value!r}")
    return ", ".join(rendered_pairs)
def print_attribute(
    title: str, value: Any, title_color: str = Fore.GREEN, value_color: str = ""
) -> None:
    """Log ``value`` at INFO level on the root logger with a title attached.

    The record's ``extra`` carries the title (any trailing colons stripped and
    exactly one appended) plus the title and value colors.

    Args:
        title: Label for the value; trailing ``:`` characters are normalized.
        value: Object to log; converted with ``str()``.
        title_color: Passed through as ``extra["title_color"]``.
        value_color: Passed through as ``extra["color"]``.
    """
    message = str(value)
    record_extra = {
        "title": f"{title.rstrip(':')}:",
        "title_color": title_color,
        "color": value_color,
    }
    logging.getLogger().info(message, extra=record_extra)
def generate_uvicorn_config():
    """
    Generates a uvicorn logging config that silences uvicorn's default logging and tells it to use the native logging module.

    The ``uvicorn``, ``uvicorn.error`` and ``uvicorn.access`` logger entries
    are replaced by entries with an empty handler list.

    Returns:
        A new, independent config dict. ``uvicorn.config.LOGGING_CONFIG``
        itself is left untouched.
    """
    import copy

    # A shallow ``dict(...)`` copy would share the nested "loggers" dict with
    # uvicorn's module-level default, so the assignments below would rewrite
    # that global for the whole process. Deep-copy to keep the edit local.
    log_config = copy.deepcopy(uvicorn.config.LOGGING_CONFIG)
    for logger_name in ("uvicorn", "uvicorn.error", "uvicorn.access"):
        log_config["loggers"][logger_name] = {"handlers": []}
    return log_config

View File

@@ -1,34 +1,17 @@
import inspect
import logging
import threading
import time
from functools import wraps
from typing import (
Awaitable,
Callable,
ParamSpec,
Protocol,
Tuple,
TypeVar,
cast,
overload,
runtime_checkable,
)
from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload
P = ParamSpec("P")
R = TypeVar("R")
logger = logging.getLogger(__name__)
@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
pass
@overload
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
pass
def thread_cached(func: Callable[P, R]) -> Callable[P, R]: ...
def thread_cached(
@@ -74,193 +57,3 @@ def thread_cached(
def clear_thread_cache(func: Callable) -> None:
    """Invoke ``func.clear_cache()`` when ``func`` has a truthy ``clear_cache`` attribute.

    Callables without that attribute are left alone.
    """
    clear = getattr(func, "clear_cache", None)
    if clear:
        clear()
FuncT = TypeVar("FuncT")
R_co = TypeVar("R_co", covariant=True)
@runtime_checkable
class AsyncCachedFunction(Protocol[P, R_co]):
    """Structural type of an async callable that also exposes cache controls.

    Describes objects with ``cache_clear()``, ``cache_info()`` and an async
    ``__call__``. The method bodies below are placeholders only.
    """

    def cache_clear(self) -> None:
        """Drop every cached entry."""
        return None

    def cache_info(self) -> dict[str, int | None]:
        """Return cache statistics as a name -> value mapping."""
        return {}

    async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
        """Invoke the cached function."""
        return None  # type: ignore
def async_ttl_cache(
    maxsize: int = 128, ttl_seconds: int | None = None
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
    """
    TTL (Time To Live) cache decorator for async functions.

    Similar to functools.lru_cache but works with async functions and includes optional TTL.

    Results are keyed on the positional arguments plus the sorted keyword
    arguments, so all arguments must be hashable. Exceptions are never cached.
    Concurrent calls with the same key are not coalesced: each one that misses
    the cache runs the wrapped function.

    Args:
        maxsize: Maximum number of cached entries. When an insert pushes the
            cache above this size, the oldest entries are dropped until only
            the newest ``maxsize // 2`` remain (at least one entry is kept
            when ``maxsize`` is 1; nothing is kept when ``maxsize`` <= 0).
        ttl_seconds: Time to live in seconds. If None, entries never expire (like lru_cache)

    Returns:
        Decorator function. The decorated function gains ``cache_clear()``
        and ``cache_info()`` (keys ``size``, ``maxsize``, ``ttl_seconds``).

    Example:
        # With TTL
        @async_ttl_cache(maxsize=1000, ttl_seconds=300)
        async def api_call(param: str) -> dict:
            return {"result": param}

        # Without TTL (permanent cache like lru_cache)
        @async_ttl_cache(maxsize=1000)
        async def expensive_computation(param: str) -> dict:
            return {"result": param}
    """
    # Number of newest entries that survive an overflow cleanup. The extra
    # terms keep one entry for maxsize == 1 (where maxsize // 2 is 0 and the
    # cleanup used to remove nothing, letting the cache grow without bound)
    # and clamp non-positive sizes to "keep nothing".
    keep_on_cleanup = max(maxsize // 2, min(maxsize, 1), 0)

    def decorator(
        async_func: Callable[P, Awaitable[R]],
    ) -> AsyncCachedFunction[P, R]:
        # key -> (result, monotonic time stored). Dicts preserve insertion
        # order, so iteration starts at the oldest entry.
        cache_storage: dict[tuple, Tuple[R, float]] = {}

        @wraps(async_func)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # Create cache key from arguments
            key = (args, tuple(sorted(kwargs.items())))

            entry = cache_storage.get(key)
            if entry is not None:
                cached_result, stored_at = entry
                # A monotonic clock keeps TTLs immune to wall-clock jumps.
                if ttl_seconds is None or time.monotonic() - stored_at < ttl_seconds:
                    logger.debug(
                        f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
                    )
                    return cached_result
                # Expired entry
                del cache_storage[key]
                logger.debug(f"Cache entry expired for {async_func.__name__}")

            # Cache miss or expired - fetch fresh data
            logger.debug(
                f"Cache miss for {async_func.__name__} with key: {str(key)[:50]}"
            )
            result = await async_func(*args, **kwargs)

            # Stamp at store time so a slow call does not eat into its TTL,
            # and re-insert at the end so insertion order matches entry age
            # even if a concurrent call stored this key in the meantime.
            cache_storage.pop(key, None)
            cache_storage[key] = (result, time.monotonic())

            # Simple FIFO cleanup when the cache gets too large
            if len(cache_storage) > maxsize:
                excess = len(cache_storage) - keep_on_cleanup
                oldest_keys = list(cache_storage)[:excess]
                for old_key in oldest_keys:
                    cache_storage.pop(old_key, None)
                logger.debug(
                    f"Cache cleanup: removed {len(oldest_keys)} entries for {async_func.__name__}"
                )

            return result

        # Add cache management methods (similar to functools.lru_cache)
        def cache_clear() -> None:
            cache_storage.clear()

        def cache_info() -> dict[str, int | None]:
            return {
                "size": len(cache_storage),
                "maxsize": maxsize,
                "ttl_seconds": ttl_seconds,
            }

        # Attach methods to wrapper
        setattr(wrapper, "cache_clear", cache_clear)
        setattr(wrapper, "cache_info", cache_info)

        return cast(AsyncCachedFunction[P, R], wrapper)

    return decorator
@overload
def async_cache(
    func: Callable[P, Awaitable[R]],
) -> AsyncCachedFunction[P, R]: ...


@overload
def async_cache(
    func: None = None,
    *,
    maxsize: int = 128,
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]: ...


def async_cache(
    func: Callable[P, Awaitable[R]] | None = None,
    *,
    maxsize: int = 128,
) -> (
    AsyncCachedFunction[P, R]
    | Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]
):
    """
    Cache decorator for async functions whose entries never expire.

    Delegates to ``async_ttl_cache`` with ``ttl_seconds=None`` and can be
    applied with or without parentheses.

    Args:
        func: The async function to cache (bare ``@async_cache`` usage);
            ``None`` when the decorator is called with parentheses.
        maxsize: Maximum number of cached entries.

    Returns:
        The cached function when ``func`` is given, otherwise a decorator.

    Example:
        @async_cache
        async def get_data(param: str) -> dict:
            return {"result": param}

        @async_cache(maxsize=1000)
        async def expensive_computation(param: str) -> dict:
            return {"result": param}
    """
    no_ttl_decorator = async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
    # Bare usage hands us the function directly; parenthesised usage expects
    # the decorator back.
    return no_ttl_decorator if func is None else no_ttl_decorator(func)

View File

@@ -16,12 +16,7 @@ from unittest.mock import Mock
import pytest
from autogpt_libs.utils.cache import (
async_cache,
async_ttl_cache,
clear_thread_cache,
thread_cached,
)
from autogpt_libs.utils.cache import clear_thread_cache, thread_cached
class TestThreadCached:
@@ -328,378 +323,3 @@ class TestThreadCached:
assert function_using_mock(2) == 42
assert mock.call_count == 2
class TestAsyncTTLCache:
    """Behavioural tests for ``async_ttl_cache``."""

    @pytest.mark.asyncio
    async def test_basic_caching(self):
        """Repeated arguments hit the cache; new arguments call through."""
        invocations = []

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def cached_function(x: int, y: int = 0) -> int:
            invocations.append((x, y))
            await asyncio.sleep(0.01)  # Simulate async work
            return x + y

        # Initial call executes the function
        assert await cached_function(1, 2) == 3
        assert len(invocations) == 1

        # Same arguments are answered from the cache
        assert await cached_function(1, 2) == 3
        assert len(invocations) == 1

        # New arguments execute the function again
        assert await cached_function(2, 3) == 5
        assert len(invocations) == 2

    @pytest.mark.asyncio
    async def test_ttl_expiration(self):
        """Entries are recomputed once their TTL has elapsed."""
        invocations = []

        @async_ttl_cache(maxsize=10, ttl_seconds=1)  # Short TTL
        async def short_lived_cache(x: int) -> int:
            invocations.append(x)
            return x * 2

        assert await short_lived_cache(5) == 10
        assert len(invocations) == 1

        # Still fresh: served from the cache
        assert await short_lived_cache(5) == 10
        assert len(invocations) == 1

        # Let the entry expire
        await asyncio.sleep(1.1)

        assert await short_lived_cache(5) == 10
        assert len(invocations) == 2

    @pytest.mark.asyncio
    async def test_cache_info(self):
        """``cache_info`` reports size, maxsize and TTL."""

        @async_ttl_cache(maxsize=5, ttl_seconds=300)
        async def info_test_function(x: int) -> int:
            return x * 3

        empty_info = info_test_function.cache_info()
        assert empty_info["size"] == 0
        assert empty_info["maxsize"] == 5
        assert empty_info["ttl_seconds"] == 300

        # One stored entry is reflected in the size
        await info_test_function(1)
        assert info_test_function.cache_info()["size"] == 1

    @pytest.mark.asyncio
    async def test_cache_clear(self):
        """``cache_clear`` forces the next call to recompute."""
        invocations = []

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def clearable_function(x: int) -> int:
            invocations.append(x)
            return x * 4

        assert await clearable_function(2) == 8
        assert len(invocations) == 1

        assert await clearable_function(2) == 8
        assert len(invocations) == 1

        clearable_function.cache_clear()

        assert await clearable_function(2) == 8
        assert len(invocations) == 2

    @pytest.mark.asyncio
    async def test_maxsize_cleanup(self):
        """Exceeding ``maxsize`` triggers a cleanup of old entries."""
        invocations = []

        @async_ttl_cache(maxsize=3, ttl_seconds=60)
        async def size_limited_function(x: int) -> int:
            invocations.append(x)
            return x**2

        # Fill the cache exactly to maxsize
        for value in (1, 2, 3):
            await size_limited_function(value)
        assert size_limited_function.cache_info()["size"] == 3

        # One more entry overflows the cache and triggers cleanup
        await size_limited_function(4)

        size_after = size_limited_function.cache_info()["size"]
        assert size_after is not None and size_after <= 3  # Should be cleaned up

    @pytest.mark.asyncio
    async def test_argument_variations(self):
        """Positional, default and keyword-only arguments all form the key."""
        invocations = []

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def arg_test_function(a: int, b: str = "default", *, c: int = 100) -> str:
            invocations.append((a, b, c))
            return f"{a}-{b}-{c}"

        first = await arg_test_function(1, "test", c=200)
        assert len(invocations) == 1

        # Identical call pattern: cache hit
        repeated = await arg_test_function(1, "test", c=200)
        assert len(invocations) == 1
        assert first == repeated

        # Different arguments: cache miss
        other = await arg_test_function(2, "test", c=200)
        assert len(invocations) == 2
        assert first != other

    @pytest.mark.asyncio
    async def test_exception_handling(self):
        """Calls that raise are not cached."""
        invocations = []

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def exception_function(x: int) -> int:
            invocations.append(x)
            if x < 0:
                raise ValueError("Negative value not allowed")
            return x * 2

        # A successful call is cached
        assert await exception_function(5) == 10
        assert len(invocations) == 1

        assert await exception_function(5) == 10
        assert len(invocations) == 1

        # A failing call runs the function ...
        with pytest.raises(ValueError):
            await exception_function(-1)
        assert len(invocations) == 2

        # ... and runs it again next time, because nothing was cached
        with pytest.raises(ValueError):
            await exception_function(-1)
        assert len(invocations) == 3

    @pytest.mark.asyncio
    async def test_concurrent_calls(self):
        """The cache stays consistent under concurrent identical calls."""
        invocations = []

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def concurrent_function(x: int) -> int:
            invocations.append(x)
            await asyncio.sleep(0.05)  # Simulate work
            return x * x

        results = await asyncio.gather(*(concurrent_function(3) for _ in range(5)))

        assert all(result == 9 for result in results)
        # Concurrent misses are not coalesced, so anywhere between one and
        # five real calls is acceptable; the point is that nothing breaks.
        assert 1 <= len(invocations) <= 5
class TestAsyncCache:
    """Behavioural tests for ``async_cache`` (cache without TTL)."""

    @pytest.mark.asyncio
    async def test_basic_caching_no_ttl(self):
        """Entries stay cached regardless of elapsed time."""
        invocations = []

        @async_cache(maxsize=10)
        async def cached_function(x: int, y: int = 0) -> int:
            invocations.append((x, y))
            await asyncio.sleep(0.01)  # Simulate async work
            return x + y

        assert await cached_function(1, 2) == 3
        assert len(invocations) == 1

        # Same arguments: cache hit
        assert await cached_function(1, 2) == 3
        assert len(invocations) == 1

        # Still a cache hit after waiting, since there is no TTL
        await asyncio.sleep(0.05)
        assert await cached_function(1, 2) == 3
        assert len(invocations) == 1

        # Different arguments: cache miss
        assert await cached_function(2, 3) == 5
        assert len(invocations) == 2

    @pytest.mark.asyncio
    async def test_no_ttl_vs_ttl_behavior(self):
        """A TTL cache expires while a no-TTL cache keeps its entry."""
        ttl_invocations = []
        no_ttl_invocations = []

        @async_ttl_cache(maxsize=10, ttl_seconds=1)  # Short TTL
        async def ttl_function(x: int) -> int:
            ttl_invocations.append(x)
            return x * 2

        @async_cache(maxsize=10)  # No TTL
        async def no_ttl_function(x: int) -> int:
            no_ttl_invocations.append(x)
            return x * 2

        await ttl_function(5)
        await no_ttl_function(5)
        assert len(ttl_invocations) == 1
        assert len(no_ttl_invocations) == 1

        # Let the TTL entry expire
        await asyncio.sleep(1.1)

        await ttl_function(5)  # Recomputed: TTL expired
        await no_ttl_function(5)  # Cache hit: no TTL
        assert len(ttl_invocations) == 2
        assert len(no_ttl_invocations) == 1

    @pytest.mark.asyncio
    async def test_async_cache_info(self):
        """``cache_info`` reports ``ttl_seconds`` as ``None``."""

        @async_cache(maxsize=5)
        async def info_test_function(x: int) -> int:
            return x * 3

        empty_info = info_test_function.cache_info()
        assert empty_info["size"] == 0
        assert empty_info["maxsize"] == 5
        assert empty_info["ttl_seconds"] is None  # No TTL

        await info_test_function(1)
        assert info_test_function.cache_info()["size"] == 1
class TestTTLOptional:
    """Tests for the optional ``ttl_seconds`` argument."""

    @pytest.mark.asyncio
    async def test_ttl_none_behavior(self):
        """``ttl_seconds=None`` means entries never expire."""
        invocations = []

        @async_ttl_cache(maxsize=10, ttl_seconds=None)
        async def no_ttl_via_none(x: int) -> int:
            invocations.append(x)
            return x**2

        assert await no_ttl_via_none(3) == 9
        assert len(invocations) == 1

        # Waiting changes nothing because there is no TTL
        await asyncio.sleep(0.1)

        assert await no_ttl_via_none(3) == 9
        assert len(invocations) == 1

        assert no_ttl_via_none.cache_info()["ttl_seconds"] is None

    @pytest.mark.asyncio
    async def test_cache_options_comparison(self):
        """TTL and process-level caches diverge only after the TTL elapses."""
        ttl_invocations = []
        process_invocations = []

        @async_ttl_cache(maxsize=10, ttl_seconds=1)  # With TTL
        async def ttl_function(x: int) -> int:
            ttl_invocations.append(x)
            return x * 10

        @async_cache(maxsize=10)  # Process-level cache (no TTL)
        async def process_function(x: int) -> int:
            process_invocations.append(x)
            return x * 10

        # Both compute once ...
        await ttl_function(3)
        await process_function(3)
        assert len(ttl_invocations) == 1
        assert len(process_invocations) == 1

        # ... and both answer an immediate repeat from the cache
        await ttl_function(3)
        await process_function(3)
        assert len(ttl_invocations) == 1
        assert len(process_invocations) == 1

        # Let the TTL entry expire
        await asyncio.sleep(1.1)

        await ttl_function(3)  # Recomputed
        await process_function(3)  # Still cached
        assert len(ttl_invocations) == 2
        assert len(process_invocations) == 1

View File

@@ -1,52 +0,0 @@
# Development and testing files
**/__pycache__
**/*.pyc
**/*.pyo
**/*.pyd
**/.Python
**/env/
**/venv/
**/.venv/
**/pip-log.txt
**/.pytest_cache/
**/test-results/
**/snapshots/
**/test/
# IDE and editor files
**/.vscode/
**/.idea/
**/*.swp
**/*.swo
*~
# OS files
.DS_Store
Thumbs.db
# Logs
**/*.log
**/logs/
# Git
.git/
.gitignore
# Documentation
**/*.md
!README.md
# Local development files
.env
.env.local
**/.env.test
# Build artifacts
**/dist/
**/build/
**/target/
# Docker files (avoid recursion)
Dockerfile*
docker-compose*
.dockerignore

View File

@@ -1,9 +1,3 @@
# Backend Configuration
# This file contains environment variables that MUST be set for the AutoGPT platform
# Variables with working defaults in settings.py are not included here
## ===== REQUIRED DATABASE CONFIGURATION ===== ##
# PostgreSQL Database Connection
DB_USER=postgres
DB_PASS=your-super-secret-and-long-postgres-password
DB_NAME=postgres
@@ -16,50 +10,72 @@ DB_SCHEMA=platform
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
PRISMA_SCHEMA="postgres/schema.prisma"
ENABLE_AUTH=true
## ===== REQUIRED SERVICE CREDENTIALS ===== ##
# Redis Configuration
# EXECUTOR
NUM_GRAPH_WORKERS=10
BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]
# generate using `from cryptography.fernet import Fernet;Fernet.generate_key().decode()`
ENCRYPTION_KEY='dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw='
UNSUBSCRIBE_SECRET_KEY = 'HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio='
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_PASSWORD=password
# RabbitMQ Credentials
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
ENABLE_CREDIT=false
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=
# Supabase Authentication
# What environment things should be logged under: local dev or prod
APP_ENV=local
# What environment to behave as: "local" or "cloud"
BEHAVE_AS=local
PYRO_HOST=localhost
SENTRY_DSN=
# Email For Postmark so we can send emails
POSTMARK_SERVER_API_TOKEN=
POSTMARK_SENDER_EMAIL=invalid@invalid.com
POSTMARK_WEBHOOK_TOKEN=
## User auth with Supabase is required for any of the 3rd party integrations with auth to work.
ENABLE_AUTH=true
SUPABASE_URL=http://localhost:8000
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
## ===== REQUIRED SECURITY KEYS ===== ##
# Generate using: from cryptography.fernet import Fernet;Fernet.generate_key().decode()
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
# RabbitMQ credentials -- Used for communication between services
RABBITMQ_HOST=localhost
RABBITMQ_PORT=5672
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
# Platform URLs (set these for webhooks and OAuth to work)
PLATFORM_BASE_URL=http://localhost:8000
FRONTEND_BASE_URL=http://localhost:3000
# Media Storage (required for marketplace and library functionality)
## GCS bucket is required for marketplace and library functionality
MEDIA_GCS_BUCKET_NAME=
## ===== API KEYS AND OAUTH CREDENTIALS ===== ##
# All API keys below are optional - only add what you need
## For local development, you may need to set FRONTEND_BASE_URL for the OAuth flow
## for integrations to work. Defaults to the value of PLATFORM_BASE_URL if not set.
# FRONTEND_BASE_URL=http://localhost:3000
# AI/LLM Services
OPENAI_API_KEY=
ANTHROPIC_API_KEY=
GROQ_API_KEY=
LLAMA_API_KEY=
AIML_API_KEY=
V0_API_KEY=
OPEN_ROUTER_API_KEY=
NVIDIA_API_KEY=
## PLATFORM_BASE_URL must be set to a *publicly accessible* URL pointing to your backend
## to use the platform's webhook-related functionality.
## If you are developing locally, you can use something like ngrok to get a public URL
## and tunnel it to your locally running backend.
PLATFORM_BASE_URL=http://localhost:3000
## Cloudflare Turnstile (CAPTCHA) Configuration
## Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
## This is the backend secret key
TURNSTILE_SECRET_KEY=
## This is the verify URL
TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify
## == INTEGRATION CREDENTIALS == ##
# Each set of server side credentials is required for the corresponding 3rd party
# integration to work.
# OAuth Credentials
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
# e.g. http://localhost:3000/auth/integrations/oauth_callback
@@ -69,6 +85,7 @@ GITHUB_CLIENT_SECRET=
# Google OAuth App server credentials - https://console.cloud.google.com/apis/credentials, and enable gmail api and set scopes
# https://console.cloud.google.com/apis/credentials/consent ?project=<your_project_id>
# You'll need to add/enable the following scopes (minimum):
# https://console.developers.google.com/apis/api/gmail.googleapis.com/overview ?project=<your_project_id>
# https://console.cloud.google.com/apis/library/sheets.googleapis.com/ ?project=<your_project_id>
@@ -104,66 +121,104 @@ LINEAR_CLIENT_SECRET=
TODOIST_CLIENT_ID=
TODOIST_CLIENT_SECRET=
NOTION_CLIENT_ID=
NOTION_CLIENT_SECRET=
## ===== OPTIONAL API KEYS ===== ##
# LLM
OPENAI_API_KEY=
ANTHROPIC_API_KEY=
AIML_API_KEY=
GROQ_API_KEY=
OPEN_ROUTER_API_KEY=
LLAMA_API_KEY=
# Reddit
# Go to https://www.reddit.com/prefs/apps and create a new app
# Choose "script" for the type
# Fill in the redirect uri as <your_frontend_url>/auth/integrations/oauth_callback, e.g. http://localhost:3000/auth/integrations/oauth_callback
REDDIT_CLIENT_ID=
REDDIT_CLIENT_SECRET=
REDDIT_USER_AGENT="AutoGPT:1.0 (by /u/autogpt)"
# Payment Processing
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=
# Email Service (for sending notifications and confirmations)
POSTMARK_SERVER_API_TOKEN=
POSTMARK_SENDER_EMAIL=invalid@invalid.com
POSTMARK_WEBHOOK_TOKEN=
# Error Tracking
SENTRY_DSN=
# Cloudflare Turnstile (CAPTCHA) Configuration
# Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
# This is the backend secret key
TURNSTILE_SECRET_KEY=
# This is the verify URL
TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify
# Feature Flags
LAUNCH_DARKLY_SDK_KEY=
# Content Generation & Media
DID_API_KEY=
FAL_API_KEY=
IDEOGRAM_API_KEY=
REPLICATE_API_KEY=
REVID_API_KEY=
SCREENSHOTONE_API_KEY=
UNREAL_SPEECH_API_KEY=
# Data & Search Services
E2B_API_KEY=
EXA_API_KEY=
JINA_API_KEY=
MEM0_API_KEY=
OPENWEATHERMAP_API_KEY=
GOOGLE_MAPS_API_KEY=
# Communication Services
# Discord
DISCORD_BOT_TOKEN=
MEDIUM_API_KEY=
MEDIUM_AUTHOR_ID=
# SMTP/Email
SMTP_SERVER=
SMTP_PORT=
SMTP_USERNAME=
SMTP_PASSWORD=
# Business & Marketing Tools
# D-ID
DID_API_KEY=
# Open Weather Map
OPENWEATHERMAP_API_KEY=
# SMTP
SMTP_SERVER=
SMTP_PORT=
SMTP_USERNAME=
SMTP_PASSWORD=
# Medium
MEDIUM_API_KEY=
MEDIUM_AUTHOR_ID=
# Google Maps
GOOGLE_MAPS_API_KEY=
# Replicate
REPLICATE_API_KEY=
# Ideogram
IDEOGRAM_API_KEY=
# Fal
FAL_API_KEY=
# Exa
EXA_API_KEY=
# E2B
E2B_API_KEY=
# Mem0
MEM0_API_KEY=
# Nvidia
NVIDIA_API_KEY=
# Apollo
APOLLO_API_KEY=
ENRICHLAYER_API_KEY=
AYRSHARE_API_KEY=
AYRSHARE_JWT_KEY=
# SmartLead
SMARTLEAD_API_KEY=
# ZeroBounce
ZEROBOUNCE_API_KEY=
# Other Services
AUTOMOD_API_KEY=
# Ayrshare
AYRSHARE_API_KEY=
AYRSHARE_JWT_KEY=
## ===== OPTIONAL API KEYS END ===== ##
# Block Error Rate Monitoring
BLOCK_ERROR_RATE_THRESHOLD=0.5
BLOCK_ERROR_RATE_CHECK_INTERVAL_SECS=86400
# Logging Configuration
LOG_LEVEL=INFO
ENABLE_CLOUD_LOGGING=false
ENABLE_FILE_LOGGING=false
# Use to manually set the log directory
# LOG_DIR=./logs
# Example Blocks Configuration
# Set to true to enable example blocks in development
# These blocks are disabled by default in production
ENABLE_EXAMPLE_BLOCKS=false
# Cloud Storage Configuration
# Cleanup interval for expired files (hours between cleanup runs, 1-24 hours)
CLOUD_STORAGE_CLEANUP_INTERVAL_HOURS=6

View File

@@ -1,4 +1,3 @@
.env
database.db
database.db-journal
dev.db

View File

@@ -8,14 +8,14 @@ WORKDIR /app
RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy
# Update package list and install build dependencies in a single layer
RUN apt-get update --allow-releaseinfo-change --fix-missing \
&& apt-get install -y \
build-essential \
libpq5 \
libz-dev \
libssl-dev \
postgresql-client
RUN apt-get update --allow-releaseinfo-change --fix-missing
# Install build dependencies
RUN apt-get install -y build-essential
RUN apt-get install -y libpq5
RUN apt-get install -y libz-dev
RUN apt-get install -y libssl-dev
RUN apt-get install -y postgresql-client
ENV POETRY_HOME=/opt/poetry
ENV POETRY_NO_INTERACTION=1
@@ -68,12 +68,6 @@ COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.tom
WORKDIR /app/autogpt_platform/backend
FROM server_dependencies AS migrate
# Migration stage only needs schema and migrations - much lighter than full backend
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations
FROM server_dependencies AS server
COPY autogpt_platform/backend /app/autogpt_platform/backend

View File

@@ -43,11 +43,11 @@ def main(**kwargs):
run_processes(
DatabaseManager().set_log_level("warning"),
ExecutionManager(),
Scheduler(),
NotificationManager(),
WebsocketServer(),
AgentServer(),
ExecutionManager(),
**kwargs,
)

View File

@@ -1,3 +1,4 @@
import asyncio
import logging
from typing import Any, Optional
@@ -14,8 +15,7 @@ from backend.data.block import (
)
from backend.data.execution import ExecutionStatus
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util.json import validate_with_jsonschema
from backend.util.retry import func_retry
from backend.util import json, retry
_logger = logging.getLogger(__name__)
@@ -49,7 +49,7 @@ class AgentExecutorBlock(Block):
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
return validate_with_jsonschema(cls.get_input_schema(data), data)
return json.validate_with_jsonschema(cls.get_input_schema(data), data)
class Output(BlockSchema):
pass
@@ -95,14 +95,23 @@ class AgentExecutorBlock(Block):
logger=logger,
):
yield name, data
except BaseException as e:
except asyncio.CancelledError:
await self._stop(
graph_exec_id=graph_exec.id,
user_id=input_data.user_id,
logger=logger,
)
logger.warning(
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} failed: {e.__class__.__name__} {str(e)}; execution is stopped."
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} was cancelled."
)
except Exception as e:
await self._stop(
graph_exec_id=graph_exec.id,
user_id=input_data.user_id,
logger=logger,
)
logger.error(
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} failed: {e}, execution is stopped."
)
raise
@@ -122,7 +131,6 @@ class AgentExecutorBlock(Block):
log_id = f"Graph #{graph_id}-V{graph_version}, exec-id: {graph_exec_id}"
logger.info(f"Starting execution of {log_id}")
yielded_node_exec_ids = set()
async for event in event_bus.listen(
user_id=user_id,
@@ -154,14 +162,6 @@ class AgentExecutorBlock(Block):
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
)
if event.node_exec_id in yielded_node_exec_ids:
logger.warning(
f"{log_id} received duplicate event for node execution {event.node_exec_id}"
)
continue
else:
yielded_node_exec_ids.add(event.node_exec_id)
if not event.block_id:
logger.warning(f"{log_id} received event without block_id {event}")
continue
@@ -181,7 +181,7 @@ class AgentExecutorBlock(Block):
)
yield output_name, output_data
@func_retry
@retry.func_retry
async def _stop(
self,
graph_exec_id: str,
@@ -197,8 +197,7 @@ class AgentExecutorBlock(Block):
await execution_utils.stop_graph_execution(
graph_exec_id=graph_exec_id,
user_id=user_id,
wait_timeout=3600,
)
logger.info(f"Execution {log_id} stopped successfully.")
except TimeoutError as e:
logger.error(f"Execution {log_id} stop timed out: {e}")
except Exception as e:
logger.error(f"Failed to stop execution {log_id}: {e}")

View File

@@ -9,24 +9,6 @@ from backend.sdk import BaseModel, Credentials, Requests
logger = getLogger(__name__)
def _convert_bools(
obj: Any,
) -> Any: # noqa: ANN401 allow Any for deep conversion utility
"""Recursively walk *obj* and coerce string booleans to real booleans."""
if isinstance(obj, str):
lowered = obj.lower()
if lowered == "true":
return True
if lowered == "false":
return False
return obj
if isinstance(obj, list):
return [_convert_bools(item) for item in obj]
if isinstance(obj, dict):
return {k: _convert_bools(v) for k, v in obj.items()}
return obj
class WebhookFilters(BaseModel):
dataTypes: list[str]
changeTypes: list[str] | None = None
@@ -597,7 +579,7 @@ async def update_table(
response = await Requests().patch(
f"https://api.airtable.com/v0/meta/bases/{base_id}/tables/{table_id}",
headers={"Authorization": credentials.auth_header()},
json=_convert_bools(params),
json=params,
)
return response.json()
@@ -627,7 +609,7 @@ async def create_field(
response = await Requests().post(
f"https://api.airtable.com/v0/meta/bases/{base_id}/tables/{table_id}/fields",
headers={"Authorization": credentials.auth_header()},
json=_convert_bools(params),
json=params,
)
return response.json()
@@ -651,7 +633,7 @@ async def update_field(
response = await Requests().patch(
f"https://api.airtable.com/v0/meta/bases/{base_id}/tables/{table_id}/fields/{field_id}",
headers={"Authorization": credentials.auth_header()},
json=_convert_bools(params),
json=params,
)
return response.json()
@@ -709,7 +691,7 @@ async def list_records(
response = await Requests().get(
f"https://api.airtable.com/v0/{base_id}/{table_id_or_name}",
headers={"Authorization": credentials.auth_header()},
json=_convert_bools(params),
params=params,
)
return response.json()
@@ -738,22 +720,20 @@ async def update_multiple_records(
typecast: bool | None = None,
) -> dict[str, dict[str, dict[str, str]]]:
params: dict[
str, str | bool | dict[str, list[str]] | list[dict[str, dict[str, str]]]
] = {}
params: dict[str, str | dict[str, list[str]] | list[dict[str, dict[str, str]]]] = {}
if perform_upsert:
params["performUpsert"] = perform_upsert
if return_fields_by_field_id:
params["returnFieldsByFieldId"] = str(return_fields_by_field_id)
if typecast:
params["typecast"] = typecast
params["typecast"] = str(typecast)
params["records"] = [_convert_bools(record) for record in records]
params["records"] = records
response = await Requests().patch(
f"https://api.airtable.com/v0/{base_id}/{table_id_or_name}",
headers={"Authorization": credentials.auth_header()},
json=_convert_bools(params),
json=params,
)
return response.json()
@@ -767,20 +747,18 @@ async def update_record(
typecast: bool | None = None,
fields: dict[str, Any] | None = None,
) -> dict[str, dict[str, dict[str, str]]]:
params: dict[str, str | bool | dict[str, Any] | list[dict[str, dict[str, str]]]] = (
{}
)
params: dict[str, str | dict[str, Any] | list[dict[str, dict[str, str]]]] = {}
if return_fields_by_field_id:
params["returnFieldsByFieldId"] = return_fields_by_field_id
params["returnFieldsByFieldId"] = str(return_fields_by_field_id)
if typecast:
params["typecast"] = typecast
params["typecast"] = str(typecast)
if fields:
params["fields"] = fields
response = await Requests().patch(
f"https://api.airtable.com/v0/{base_id}/{table_id_or_name}/{record_id}",
headers={"Authorization": credentials.auth_header()},
json=_convert_bools(params),
json=params,
)
return response.json()
@@ -801,22 +779,21 @@ async def create_record(
len(records) <= 10
), "Only up to 10 records can be provided when using records"
params: dict[str, str | bool | dict[str, Any] | list[dict[str, Any]]] = {}
params: dict[str, str | dict[str, Any] | list[dict[str, Any]]] = {}
if fields:
params["fields"] = fields
if records:
params["records"] = records
if return_fields_by_field_id:
params["returnFieldsByFieldId"] = return_fields_by_field_id
params["returnFieldsByFieldId"] = str(return_fields_by_field_id)
if typecast:
params["typecast"] = typecast
params["typecast"] = str(typecast)
response = await Requests().post(
f"https://api.airtable.com/v0/{base_id}/{table_id_or_name}",
headers={"Authorization": credentials.auth_header()},
json=_convert_bools(params),
json=params,
)
return response.json()
@@ -873,7 +850,7 @@ async def create_webhook(
response = await Requests().post(
f"https://api.airtable.com/v0/bases/{base_id}/webhooks",
headers={"Authorization": credentials.auth_header()},
json=_convert_bools(params),
json=params,
)
return response.json()
@@ -1218,7 +1195,7 @@ async def create_base(
"Authorization": credentials.auth_header(),
"Content-Type": "application/json",
},
json=_convert_bools(params),
json=params,
)
return response.json()

View File

@@ -159,7 +159,6 @@ class AirtableOAuthHandler(BaseOAuthHandler):
logger.info("Successfully refreshed tokens")
new_credentials = OAuth2Credentials(
id=credentials.id,
access_token=SecretStr(response.access_token),
refresh_token=SecretStr(response.refresh_token),
access_token_expires_at=int(time.time()) + response.expires_in,

View File

@@ -4,19 +4,11 @@ from typing import Optional
from pydantic import BaseModel, Field
from backend.data.block import BlockSchema
from backend.data.model import SchemaField, UserIntegrations
from backend.data.model import SchemaField
from backend.integrations.ayrshare import AyrshareClient
from backend.util.clients import get_database_manager_async_client
from backend.util.exceptions import MissingConfigError
async def get_profile_key(user_id: str):
    """Return the Ayrshare profile key from the user's managed credentials.

    The user's integrations are fetched through the async database manager
    client, and the ``ayrshare_profile_key`` stored under
    ``managed_credentials`` is returned.
    """
    db_client = get_database_manager_async_client()
    integrations: UserIntegrations = await db_client.get_user_integrations(user_id)
    return integrations.managed_credentials.ayrshare_profile_key
class BaseAyrshareInput(BlockSchema):
"""Base input model for Ayrshare social media posts with common fields."""

View File

@@ -6,9 +6,10 @@ from backend.sdk import (
BlockSchema,
BlockType,
SchemaField,
SecretStr,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._util import BaseAyrshareInput, create_ayrshare_client
class PostToBlueskyBlock(Block):
@@ -57,12 +58,10 @@ class PostToBlueskyBlock(Block):
self,
input_data: "PostToBlueskyBlock.Input",
*,
user_id: str,
profile_key: SecretStr,
**kwargs,
) -> BlockOutput:
"""Post to Bluesky with Bluesky-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return

View File

@@ -6,14 +6,10 @@ from backend.sdk import (
BlockSchema,
BlockType,
SchemaField,
SecretStr,
)
from ._util import (
BaseAyrshareInput,
CarouselItem,
create_ayrshare_client,
get_profile_key,
)
from ._util import BaseAyrshareInput, CarouselItem, create_ayrshare_client
class PostToFacebookBlock(Block):
@@ -120,11 +116,10 @@ class PostToFacebookBlock(Block):
self,
input_data: "PostToFacebookBlock.Input",
*,
user_id: str,
profile_key: SecretStr,
**kwargs,
) -> BlockOutput:
"""Post to Facebook with Facebook-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return

View File

@@ -6,9 +6,10 @@ from backend.sdk import (
BlockSchema,
BlockType,
SchemaField,
SecretStr,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._util import BaseAyrshareInput, create_ayrshare_client
class PostToGMBBlock(Block):
@@ -110,10 +111,9 @@ class PostToGMBBlock(Block):
)
async def run(
self, input_data: "PostToGMBBlock.Input", *, user_id: str, **kwargs
self, input_data: "PostToGMBBlock.Input", *, profile_key: SecretStr, **kwargs
) -> BlockOutput:
"""Post to Google My Business with GMB-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return

View File

@@ -8,14 +8,10 @@ from backend.sdk import (
BlockSchema,
BlockType,
SchemaField,
SecretStr,
)
from ._util import (
BaseAyrshareInput,
InstagramUserTag,
create_ayrshare_client,
get_profile_key,
)
from ._util import BaseAyrshareInput, InstagramUserTag, create_ayrshare_client
class PostToInstagramBlock(Block):
@@ -112,11 +108,10 @@ class PostToInstagramBlock(Block):
self,
input_data: "PostToInstagramBlock.Input",
*,
user_id: str,
profile_key: SecretStr,
**kwargs,
) -> BlockOutput:
"""Post to Instagram with Instagram-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return

View File

@@ -6,9 +6,10 @@ from backend.sdk import (
BlockSchema,
BlockType,
SchemaField,
SecretStr,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._util import BaseAyrshareInput, create_ayrshare_client
class PostToLinkedInBlock(Block):
@@ -112,11 +113,10 @@ class PostToLinkedInBlock(Block):
self,
input_data: "PostToLinkedInBlock.Input",
*,
user_id: str,
profile_key: SecretStr,
**kwargs,
) -> BlockOutput:
"""Post to LinkedIn with LinkedIn-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return

View File

@@ -6,14 +6,10 @@ from backend.sdk import (
BlockSchema,
BlockType,
SchemaField,
SecretStr,
)
from ._util import (
BaseAyrshareInput,
PinterestCarouselOption,
create_ayrshare_client,
get_profile_key,
)
from ._util import BaseAyrshareInput, PinterestCarouselOption, create_ayrshare_client
class PostToPinterestBlock(Block):
@@ -92,11 +88,10 @@ class PostToPinterestBlock(Block):
self,
input_data: "PostToPinterestBlock.Input",
*,
user_id: str,
profile_key: SecretStr,
**kwargs,
) -> BlockOutput:
"""Post to Pinterest with Pinterest-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return

View File

@@ -6,9 +6,10 @@ from backend.sdk import (
BlockSchema,
BlockType,
SchemaField,
SecretStr,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._util import BaseAyrshareInput, create_ayrshare_client
class PostToRedditBlock(Block):
@@ -35,9 +36,8 @@ class PostToRedditBlock(Block):
)
async def run(
self, input_data: "PostToRedditBlock.Input", *, user_id: str, **kwargs
self, input_data: "PostToRedditBlock.Input", *, profile_key: SecretStr, **kwargs
) -> BlockOutput:
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return

View File

@@ -6,9 +6,10 @@ from backend.sdk import (
BlockSchema,
BlockType,
SchemaField,
SecretStr,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._util import BaseAyrshareInput, create_ayrshare_client
class PostToSnapchatBlock(Block):
@@ -62,11 +63,10 @@ class PostToSnapchatBlock(Block):
self,
input_data: "PostToSnapchatBlock.Input",
*,
user_id: str,
profile_key: SecretStr,
**kwargs,
) -> BlockOutput:
"""Post to Snapchat with Snapchat-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return

View File

@@ -6,9 +6,10 @@ from backend.sdk import (
BlockSchema,
BlockType,
SchemaField,
SecretStr,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._util import BaseAyrshareInput, create_ayrshare_client
class PostToTelegramBlock(Block):
@@ -57,11 +58,10 @@ class PostToTelegramBlock(Block):
self,
input_data: "PostToTelegramBlock.Input",
*,
user_id: str,
profile_key: SecretStr,
**kwargs,
) -> BlockOutput:
"""Post to Telegram with Telegram-specific validation."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return

View File

@@ -6,9 +6,10 @@ from backend.sdk import (
BlockSchema,
BlockType,
SchemaField,
SecretStr,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._util import BaseAyrshareInput, create_ayrshare_client
class PostToThreadsBlock(Block):
@@ -50,11 +51,10 @@ class PostToThreadsBlock(Block):
self,
input_data: "PostToThreadsBlock.Input",
*,
user_id: str,
profile_key: SecretStr,
**kwargs,
) -> BlockOutput:
"""Post to Threads with Threads-specific validation."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return

View File

@@ -1,5 +1,3 @@
from enum import Enum
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
@@ -8,15 +6,10 @@ from backend.sdk import (
BlockSchema,
BlockType,
SchemaField,
SecretStr,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class TikTokVisibility(str, Enum):
    """Audience settings accepted for the TikTok ``visibility`` post option.

    The ``str`` mix-in makes every member compare equal to its raw string
    value, so members can be used wherever the plain string is expected.
    """

    PUBLIC = "public"
    PRIVATE = "private"
    FOLLOWERS = "followers"
    # The ``visibility`` field description lists 'friends' next to the other
    # three values; without a member for it that option cannot be selected.
    FRIENDS = "friends"
from ._util import BaseAyrshareInput, create_ayrshare_client
class PostToTikTokBlock(Block):
@@ -28,6 +21,7 @@ class PostToTikTokBlock(Block):
# Override post field to include TikTok-specific information
post: str = SchemaField(
description="The post text (max 2,200 chars, empty string allowed). Use @handle to mention users. Line breaks will be ignored.",
default="",
advanced=False,
)
@@ -40,7 +34,7 @@ class PostToTikTokBlock(Block):
# TikTok-specific options
auto_add_music: bool = SchemaField(
description="Whether to automatically add recommended music to the post. If you set this field to true, you can change the music later in the TikTok app.",
description="Automatically add recommended music to image posts",
default=False,
advanced=True,
)
@@ -60,17 +54,17 @@ class PostToTikTokBlock(Block):
advanced=True,
)
is_ai_generated: bool = SchemaField(
description="If you enable the toggle, your video will be labeled as “Creator labeled as AI-generated” once posted and cant be changed. The “Creator labeled as AI-generated” label indicates that the content was completely AI-generated or significantly edited with AI.",
description="Label content as AI-generated (video only)",
default=False,
advanced=True,
)
is_branded_content: bool = SchemaField(
description="Whether to enable the Branded Content toggle. If this field is set to true, the video will be labeled as Branded Content, indicating you are in a paid partnership with a brand. A “Paid partnership” label will be attached to the video.",
description="Label as branded content (paid partnership)",
default=False,
advanced=True,
)
is_brand_organic: bool = SchemaField(
description="Whether to enable the Brand Organic Content toggle. If this field is set to true, the video will be labeled as Brand Organic Content, indicating you are promoting yourself or your own business. A “Promotional content” label will be attached to the video.",
description="Label as brand organic content (promotional)",
default=False,
advanced=True,
)
@@ -87,9 +81,9 @@ class PostToTikTokBlock(Block):
default=0,
advanced=True,
)
visibility: TikTokVisibility = SchemaField(
visibility: str = SchemaField(
description="Post visibility: 'public', 'private', 'followers', or 'friends'",
default=TikTokVisibility.PUBLIC,
default="public",
advanced=True,
)
draft: bool = SchemaField(
@@ -104,6 +98,7 @@ class PostToTikTokBlock(Block):
def __init__(self):
super().__init__(
disabled=True,
id="7faf4b27-96b0-4f05-bf64-e0de54ae74e1",
description="Post to TikTok using Ayrshare",
categories={BlockCategory.SOCIAL},
@@ -113,10 +108,9 @@ class PostToTikTokBlock(Block):
)
async def run(
self, input_data: "PostToTikTokBlock.Input", *, user_id: str, **kwargs
self, input_data: "PostToTikTokBlock.Input", *, profile_key: SecretStr, **kwargs
) -> BlockOutput:
"""Post to TikTok with TikTok-specific validation and options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
@@ -166,6 +160,12 @@ class PostToTikTokBlock(Block):
yield "error", f"Image cover index {input_data.image_cover_index} is out of range (max: {len(input_data.media_urls) - 1})"
return
# Validate visibility option
valid_visibility = ["public", "private", "followers", "friends"]
if input_data.visibility not in valid_visibility:
yield "error", f"TikTok visibility must be one of: {', '.join(valid_visibility)}"
return
# Check for PNG files (not supported)
has_png = any(url.lower().endswith(".png") for url in input_data.media_urls)
if has_png:
@@ -218,8 +218,8 @@ class PostToTikTokBlock(Block):
if input_data.title:
tiktok_options["title"] = input_data.title
if input_data.visibility != TikTokVisibility.PUBLIC:
tiktok_options["visibility"] = input_data.visibility.value
if input_data.visibility != "public":
tiktok_options["visibility"] = input_data.visibility
response = await client.create_post(
post=input_data.post,

View File

@@ -6,9 +6,10 @@ from backend.sdk import (
BlockSchema,
BlockType,
SchemaField,
SecretStr,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._util import BaseAyrshareInput, create_ayrshare_client
class PostToXBlock(Block):
@@ -115,11 +116,10 @@ class PostToXBlock(Block):
self,
input_data: "PostToXBlock.Input",
*,
user_id: str,
profile_key: SecretStr,
**kwargs,
) -> BlockOutput:
"""Post to X / Twitter with enhanced X-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return

View File

@@ -9,9 +9,10 @@ from backend.sdk import (
BlockSchema,
BlockType,
SchemaField,
SecretStr,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._util import BaseAyrshareInput, create_ayrshare_client
class YouTubeVisibility(str, Enum):
@@ -137,12 +138,10 @@ class PostToYouTubeBlock(Block):
self,
input_data: "PostToYouTubeBlock.Input",
*,
user_id: str,
profile_key: SecretStr,
**kwargs,
) -> BlockOutput:
"""Post to YouTube with YouTube-specific validation and options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return

File diff suppressed because it is too large Load Diff

View File

@@ -1,408 +0,0 @@
"""
API module for Enrichlayer integration.
This module provides a client for interacting with the Enrichlayer API,
which allows fetching LinkedIn profile data and related information.
"""
import datetime
import enum
import logging
from json import JSONDecodeError
from typing import Any, Optional, TypeVar
from pydantic import BaseModel, Field
from backend.data.model import APIKeyCredentials
from backend.util.request import Requests
logger = logging.getLogger(__name__)
T = TypeVar("T")
class EnrichlayerAPIException(Exception):
    """Error type used to report Enrichlayer API failures.

    Attributes are documented on ``__init__``.
    """

    def __init__(self, message: str, status_code: int):
        """Create the exception.

        Args:
            message: Text passed to ``Exception`` and returned by ``str()``.
            status_code: Status code stored on the instance for callers to
                inspect (presumably the HTTP status of the failed request —
                confirm against the raising code).
        """
        self.status_code = status_code
        super().__init__(message)
class FallbackToCache(enum.Enum):
    # Allowed values for the `fallback_to_cache` query parameter that
    # EnrichlayerClient.fetch_profile sends.
    ON_ERROR = "on-error"
    NEVER = "never"
class UseCache(enum.Enum):
    # Allowed values for the `use_cache` query parameter that
    # EnrichlayerClient.fetch_profile sends.
    IF_PRESENT = "if-present"
    NEVER = "never"
class SocialMediaProfiles(BaseModel):
    """Social media profiles model."""

    # Every handle is optional and defaults to None.
    twitter: str | None = None
    facebook: str | None = None
    github: str | None = None
class Experience(BaseModel):
    """Experience model for LinkedIn profiles."""

    # All fields are optional; dates are dicts of int parts
    # (e.g. {"year": 2020, "month": 1} as in the PersonProfileResponse example).
    company: str | None = None
    title: str | None = None
    description: str | None = None
    location: str | None = None
    starts_at: dict[str, int] | None = None
    ends_at: dict[str, int] | None = None
    company_linkedin_profile_url: str | None = None
class Education(BaseModel):
    """Education model for LinkedIn profiles."""

    # All fields are optional; dates are dicts of int parts, same shape as
    # Experience.starts_at / ends_at.
    school: str | None = None
    degree_name: str | None = None
    field_of_study: str | None = None
    starts_at: dict[str, int] | None = None
    ends_at: dict[str, int] | None = None
    school_linkedin_profile_url: str | None = None
class PersonProfileResponse(BaseModel):
    """Response model for LinkedIn person profile.

    This model represents the response from Enrichlayer's LinkedIn profile API.
    The API returns comprehensive profile data including work experience,
    education, skills, and contact information (when available).

    Example API Response:
    {
        "public_identifier": "johnsmith",
        "full_name": "John Smith",
        "occupation": "Software Engineer at Tech Corp",
        "experiences": [
            {
                "company": "Tech Corp",
                "title": "Software Engineer",
                "starts_at": {"year": 2020, "month": 1}
            }
        ],
        "education": [...],
        "skills": ["Python", "JavaScript", ...]
    }
    """

    # Identity and headline information.
    public_identifier: str | None = None
    profile_pic_url: str | None = None
    full_name: str | None = None
    first_name: str | None = None
    last_name: str | None = None
    occupation: str | None = None
    headline: str | None = None
    summary: str | None = None

    # Location.
    country: str | None = None
    country_full_name: str | None = None
    city: str | None = None
    state: str | None = None

    # Career history and abilities.
    experiences: list[Experience] | None = None
    education: list[Education] | None = None
    languages: list[str] | None = None
    skills: list[str] | None = None

    # Optional extras; EnrichlayerClient.fetch_profile only asks for these
    # when the matching include_* flag is set.
    inferred_salary: dict[str, Any] | None = None
    personal_email: str | None = None
    personal_contact_number: str | None = None
    social_media_profiles: SocialMediaProfiles | None = None
    extra: dict[str, Any] | None = None
class SimilarProfile(BaseModel):
    """Similar profile model for LinkedIn person lookup."""

    # Both fields are required.
    similarity: float
    linkedin_profile_url: str
class PersonLookupResponse(BaseModel):
    """Response model for LinkedIn person lookup.

    This model represents the response from Enrichlayer's person lookup API.
    The API returns a LinkedIn profile URL and similarity scores when
    searching for a person by name and company.

    Example API Response:
    {
        "url": "https://www.linkedin.com/in/johnsmith/",
        "name_similarity_score": 0.95,
        "company_similarity_score": 0.88,
        "title_similarity_score": 0.75,
        "location_similarity_score": 0.60
    }
    """

    url: str | None = None
    # The scores default to None so that a response which omits them still
    # validates. EnrichlayerClient.lookup_person only requests similarity
    # checks when include_similarity_checks is set, and a `float | None`
    # annotation without a default is a *required* field.
    # NOTE(review): assumes pydantic v2 field semantics — confirm the pinned version.
    name_similarity_score: float | None = None
    company_similarity_score: float | None = None
    title_similarity_score: float | None = None
    location_similarity_score: float | None = None
    last_updated: datetime.datetime | None = None
    # Full profile payload; presumably only present when enrichment was requested — confirm.
    profile: PersonProfileResponse | None = None
class RoleLookupResponse(BaseModel):
    """Response model for LinkedIn role lookup.

    This model represents the response from Enrichlayer's role lookup API.
    The API returns LinkedIn profile data for a specific role at a company.

    Example API Response:
    {
        "linkedin_profile_url": "https://www.linkedin.com/in/johnsmith/",
        "profile_data": {...} // Full PersonProfileResponse data when enrich_profile=True
    }
    """

    # Both fields are optional and default to None.
    linkedin_profile_url: str | None = None
    profile_data: PersonProfileResponse | None = None
class ProfilePictureResponse(BaseModel):
    """Response model for LinkedIn profile picture.

    This model represents the response from Enrichlayer's profile picture API.
    The API returns a URL to the person's LinkedIn profile picture.

    Example API Response:
    {
        "tmp_profile_pic_url": "https://media.licdn.com/dms/image/..."
    }
    """

    # Required field; the alias is identical to the attribute name.
    tmp_profile_pic_url: str = Field(
        ...,
        alias="tmp_profile_pic_url",
        description="URL of the profile picture",
    )

    @property
    def profile_picture_url(self) -> str:
        """Backward compatibility property for profile_picture_url."""
        picture_url = self.tmp_profile_pic_url
        return picture_url
class EnrichlayerClient:
    """Client for interacting with the Enrichlayer API.

    All request methods are coroutines. Each one issues a GET against
    ``API_BASE_URL`` and validates the decoded JSON into the matching
    response model, raising ``EnrichlayerAPIException`` on a non-OK response.
    """

    API_BASE_URL = "https://enrichlayer.com/api/v2"

    def __init__(
        self,
        credentials: Optional[APIKeyCredentials] = None,
        custom_requests: Optional[Requests] = None,
    ):
        """
        Initialize the Enrichlayer client.

        Args:
            credentials: The credentials to use for authentication. When
                given, the API key is sent as a Bearer token.
            custom_requests: Custom Requests instance for testing. When
                given it is used as-is and ``credentials`` is ignored.
        """
        if custom_requests:
            self._requests = custom_requests
        else:
            headers: dict[str, str] = {
                "Content-Type": "application/json",
            }
            if credentials:
                headers["Authorization"] = (
                    f"Bearer {credentials.api_key.get_secret_value()}"
                )
            # raise_for_status is disabled so _handle_response can build a
            # descriptive EnrichlayerAPIException itself.
            self._requests = Requests(
                extra_headers=headers,
                raise_for_status=False,
            )

    async def _handle_response(self, response) -> Any:
        """
        Handle API response and check for errors.

        Args:
            response: The response object from the request.

        Returns:
            The decoded JSON body of a successful response.

        Raises:
            EnrichlayerAPIException: If the API request fails.
        """
        if not response.ok:
            try:
                error_data = response.json()
            except JSONDecodeError:
                # NOTE(review): `text` and `status_code` are read as plain
                # attributes; confirm the response type returned by
                # backend.util.request exposes them that way.
                error_message = response.text
            else:
                # A JSON error body is not guaranteed to be an object; calling
                # .get on a list or string would raise AttributeError and hide
                # the real API failure.
                if isinstance(error_data, dict):
                    error_message = error_data.get("message", "")
                else:
                    error_message = error_data
            raise EnrichlayerAPIException(
                f"Enrichlayer API request failed ({response.status_code}): {error_message}",
                response.status_code,
            )
        return response.json()

    async def fetch_profile(
        self,
        linkedin_url: str,
        fallback_to_cache: FallbackToCache = FallbackToCache.ON_ERROR,
        use_cache: UseCache = UseCache.IF_PRESENT,
        include_skills: bool = False,
        include_inferred_salary: bool = False,
        include_personal_email: bool = False,
        include_personal_contact_number: bool = False,
        include_social_media: bool = False,
        include_extra: bool = False,
    ) -> PersonProfileResponse:
        """
        Fetch a LinkedIn profile with optional parameters.

        Args:
            linkedin_url: The LinkedIn profile URL to fetch.
            fallback_to_cache: Cache usage if live fetch fails ('on-error' or 'never').
            use_cache: Cache utilization ('if-present' or 'never').
            include_skills: Whether to include skills data.
            include_inferred_salary: Whether to include inferred salary data.
            include_personal_email: Whether to include personal email.
            include_personal_contact_number: Whether to include personal contact number.
            include_social_media: Whether to include social media profiles.
            include_extra: Whether to include additional data.

        Returns:
            The LinkedIn profile data.

        Raises:
            EnrichlayerAPIException: If the API request fails.
        """
        params = {
            "url": linkedin_url,
            "fallback_to_cache": fallback_to_cache.value.lower(),
            "use_cache": use_cache.value.lower(),
        }
        # Optional sections are opt-in: each one is requested with "include".
        if include_skills:
            params["skills"] = "include"
        if include_inferred_salary:
            params["inferred_salary"] = "include"
        if include_personal_email:
            params["personal_email"] = "include"
        if include_personal_contact_number:
            params["personal_contact_number"] = "include"
        if include_social_media:
            # One flag switches on all three social profile id parameters.
            params["twitter_profile_id"] = "include"
            params["facebook_profile_id"] = "include"
            params["github_profile_id"] = "include"
        if include_extra:
            params["extra"] = "include"
        response = await self._requests.get(
            f"{self.API_BASE_URL}/profile", params=params
        )
        return PersonProfileResponse(**await self._handle_response(response))

    async def lookup_person(
        self,
        first_name: str,
        company_domain: str,
        last_name: str | None = None,
        location: Optional[str] = None,
        title: Optional[str] = None,
        include_similarity_checks: bool = False,
        enrich_profile: bool = False,
    ) -> PersonLookupResponse:
        """
        Look up a LinkedIn profile by person's information.

        Args:
            first_name: The person's first name.
            company_domain: The domain of the company they work for.
            last_name: The person's last name.
            location: The person's location.
            title: The person's job title.
            include_similarity_checks: Whether to include similarity checks.
            enrich_profile: Whether to enrich the profile.

        Returns:
            The LinkedIn profile lookup result.

        Raises:
            EnrichlayerAPIException: If the API request fails.
        """
        params = {"first_name": first_name, "company_domain": company_domain}
        # Falsy optional values (None or "") are left out of the query.
        if last_name:
            params["last_name"] = last_name
        if location:
            params["location"] = location
        if title:
            params["title"] = title
        if include_similarity_checks:
            params["similarity_checks"] = "include"
        if enrich_profile:
            params["enrich_profile"] = "enrich"
        response = await self._requests.get(
            f"{self.API_BASE_URL}/profile/resolve", params=params
        )
        return PersonLookupResponse(**await self._handle_response(response))

    async def lookup_role(
        self, role: str, company_name: str, enrich_profile: bool = False
    ) -> RoleLookupResponse:
        """
        Look up a LinkedIn profile by role in a company.

        Args:
            role: The role title (e.g., CEO, CTO).
            company_name: The name of the company.
            enrich_profile: Whether to enrich the profile.

        Returns:
            The LinkedIn profile lookup result.

        Raises:
            EnrichlayerAPIException: If the API request fails.
        """
        params = {
            "role": role,
            "company_name": company_name,
        }
        if enrich_profile:
            params["enrich_profile"] = "enrich"
        response = await self._requests.get(
            f"{self.API_BASE_URL}/find/company/role", params=params
        )
        return RoleLookupResponse(**await self._handle_response(response))

    async def get_profile_picture(
        self, linkedin_profile_url: str
    ) -> ProfilePictureResponse:
        """
        Get a LinkedIn profile picture URL.

        Args:
            linkedin_profile_url: The LinkedIn profile URL.

        Returns:
            The profile picture URL.

        Raises:
            EnrichlayerAPIException: If the API request fails.
        """
        params = {
            "linkedin_person_profile_url": linkedin_profile_url,
        }
        response = await self._requests.get(
            f"{self.API_BASE_URL}/person/profile-picture", params=params
        )
        return ProfilePictureResponse(**await self._handle_response(response))

View File

@@ -1,34 +0,0 @@
"""
Authentication module for Enrichlayer API integration.
This module provides credential types and test credentials for the Enrichlayer API.
"""
from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsMetaInput
from backend.integrations.providers import ProviderName
# Credentials input type for Enrichlayer blocks: restricted to the
# ENRICHLAYER provider and the "api_key" credential type.
EnrichlayerCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.ENRICHLAYER], Literal["api_key"]
]
# Mock API-key credentials for tests; the key is a placeholder value, not a real secret.
TEST_CREDENTIALS = APIKeyCredentials(
id="1234a567-89bc-4def-ab12-3456cdef7890",
provider="enrichlayer",
api_key=SecretStr("mock-enrichlayer-api-key"),
title="Mock Enrichlayer API key",
expires_at=None,
)
# Dictionary form of TEST_CREDENTIALS (provider, id, type, title) for use as
# the `credentials` entry of a block's test input. The API key itself is not included.
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}

View File

@@ -1,527 +0,0 @@
"""
Block definitions for Enrichlayer API integration.
This module implements blocks for interacting with the Enrichlayer API,
which provides access to LinkedIn profile data and related information.
"""
import logging
from typing import Optional
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
from backend.util.type import MediaFileType
from ._api import (
EnrichlayerClient,
Experience,
FallbackToCache,
PersonLookupResponse,
PersonProfileResponse,
RoleLookupResponse,
UseCache,
)
from ._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, EnrichlayerCredentialsInput
# Module-level logger; the blocks' run() methods use it to report failed API calls.
logger = logging.getLogger(__name__)
class GetLinkedinProfileBlock(Block):
    """Block to fetch LinkedIn profile data using Enrichlayer API."""

    class Input(BlockSchema):
        """Input schema for GetLinkedinProfileBlock."""

        linkedin_url: str = SchemaField(
            description="LinkedIn profile URL to fetch data from",
            placeholder="https://www.linkedin.com/in/username/",
        )
        fallback_to_cache: FallbackToCache = SchemaField(
            description="Cache usage if live fetch fails",
            default=FallbackToCache.ON_ERROR,
            advanced=True,
        )
        use_cache: UseCache = SchemaField(
            description="Cache utilization strategy",
            default=UseCache.IF_PRESENT,
            advanced=True,
        )
        include_skills: bool = SchemaField(
            description="Include skills data",
            default=False,
            advanced=True,
        )
        include_inferred_salary: bool = SchemaField(
            description="Include inferred salary data",
            default=False,
            advanced=True,
        )
        include_personal_email: bool = SchemaField(
            description="Include personal email",
            default=False,
            advanced=True,
        )
        include_personal_contact_number: bool = SchemaField(
            description="Include personal contact number",
            default=False,
            advanced=True,
        )
        include_social_media: bool = SchemaField(
            description="Include social media profiles",
            default=False,
            advanced=True,
        )
        include_extra: bool = SchemaField(
            description="Include additional data",
            default=False,
            advanced=True,
        )
        credentials: EnrichlayerCredentialsInput = CredentialsField(
            description="Enrichlayer API credentials"
        )

    class Output(BlockSchema):
        """Output schema for GetLinkedinProfileBlock."""

        profile: PersonProfileResponse = SchemaField(
            description="LinkedIn profile data"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        """Register the block's id, schemas and self-test fixtures."""

        def _sample_profile() -> PersonProfileResponse:
            # Builds a fresh fixture on every call, so the expected output and
            # each mocked fetch get their own instance.
            return PersonProfileResponse(
                public_identifier="williamhgates",
                full_name="Bill Gates",
                occupation="Co-chair at Bill & Melinda Gates Foundation",
                experiences=[
                    Experience(
                        company="Bill & Melinda Gates Foundation",
                        title="Co-chair",
                        starts_at={"year": 2000},
                    )
                ],
            )

        super().__init__(
            id="f6e0ac73-4f1d-4acb-b4b7-b67066c5984e",
            description="Fetch LinkedIn profile data using Enrichlayer",
            categories={BlockCategory.SOCIAL},
            input_schema=GetLinkedinProfileBlock.Input,
            output_schema=GetLinkedinProfileBlock.Output,
            test_input={
                "linkedin_url": "https://www.linkedin.com/in/williamhgates/",
                "include_skills": True,
                "include_social_media": True,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_output=[("profile", _sample_profile())],
            test_credentials=TEST_CREDENTIALS,
            test_mock={
                "_fetch_profile": lambda *args, **kwargs: _sample_profile(),
            },
        )

    @staticmethod
    async def _fetch_profile(
        credentials: APIKeyCredentials,
        linkedin_url: str,
        fallback_to_cache: FallbackToCache = FallbackToCache.ON_ERROR,
        use_cache: UseCache = UseCache.IF_PRESENT,
        include_skills: bool = False,
        include_inferred_salary: bool = False,
        include_personal_email: bool = False,
        include_personal_contact_number: bool = False,
        include_social_media: bool = False,
        include_extra: bool = False,
    ):
        """Create a client for ``credentials`` and fetch the profile.

        Kept as a separate static method so the block's ``test_mock`` can
        replace the network call.
        """
        return await EnrichlayerClient(credentials).fetch_profile(
            linkedin_url=linkedin_url,
            fallback_to_cache=fallback_to_cache,
            use_cache=use_cache,
            include_skills=include_skills,
            include_inferred_salary=include_inferred_salary,
            include_personal_email=include_personal_email,
            include_personal_contact_number=include_personal_contact_number,
            include_social_media=include_social_media,
            include_extra=include_extra,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Fetch the profile described by ``input_data``.

        Args:
            input_data: Input parameters for the block
            credentials: API key credentials for Enrichlayer
            **kwargs: Additional keyword arguments (unused here)

        Yields:
            ("profile", data) on success, or ("error", message) if anything
            raised while fetching.
        """
        try:
            fetch_kwargs = {
                "credentials": credentials,
                "linkedin_url": input_data.linkedin_url,
                "fallback_to_cache": input_data.fallback_to_cache,
                "use_cache": input_data.use_cache,
                "include_skills": input_data.include_skills,
                "include_inferred_salary": input_data.include_inferred_salary,
                "include_personal_email": input_data.include_personal_email,
                "include_personal_contact_number": input_data.include_personal_contact_number,
                "include_social_media": input_data.include_social_media,
                "include_extra": input_data.include_extra,
            }
            yield "profile", await self._fetch_profile(**fetch_kwargs)
        except Exception as err:
            # Failures are surfaced on the "error" output instead of raising.
            logger.error(f"Error fetching LinkedIn profile: {str(err)}")
            yield "error", str(err)
class LinkedinPersonLookupBlock(Block):
    """Block to look up LinkedIn profiles by person's information using Enrichlayer API."""

    class Input(BlockSchema):
        """Input schema for LinkedinPersonLookupBlock."""

        first_name: str = SchemaField(
            description="Person's first name",
            placeholder="John",
            advanced=False,
        )
        last_name: str | None = SchemaField(
            description="Person's last name",
            placeholder="Doe",
            default=None,
            advanced=False,
        )
        # Required: there is no default here and the client's lookup_person
        # takes company_domain as a mandatory argument, so the description
        # must not advertise it as optional.
        company_domain: str = SchemaField(
            description="Domain of the company they work for",
            placeholder="example.com",
            advanced=False,
        )
        location: Optional[str] = SchemaField(
            description="Person's location (optional)",
            placeholder="San Francisco",
            default=None,
        )
        title: Optional[str] = SchemaField(
            description="Person's job title (optional)",
            placeholder="CEO",
            default=None,
        )
        include_similarity_checks: bool = SchemaField(
            description="Include similarity checks",
            default=False,
            advanced=True,
        )
        enrich_profile: bool = SchemaField(
            description="Enrich the profile with additional data",
            default=False,
            advanced=True,
        )
        credentials: EnrichlayerCredentialsInput = CredentialsField(
            description="Enrichlayer API credentials"
        )

    class Output(BlockSchema):
        """Output schema for LinkedinPersonLookupBlock."""

        lookup_result: PersonLookupResponse = SchemaField(
            description="LinkedIn profile lookup result"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        """Initialize LinkedinPersonLookupBlock with its schemas and test fixtures."""
        super().__init__(
            id="d237a98a-5c4b-4a1c-b9e3-e6f9a6c81df7",
            description="Look up LinkedIn profiles by person information using Enrichlayer",
            categories={BlockCategory.SOCIAL},
            input_schema=LinkedinPersonLookupBlock.Input,
            output_schema=LinkedinPersonLookupBlock.Output,
            test_input={
                "first_name": "Bill",
                "last_name": "Gates",
                "company_domain": "gatesfoundation.org",
                "include_similarity_checks": True,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_output=[
                (
                    "lookup_result",
                    PersonLookupResponse(
                        url="https://www.linkedin.com/in/williamhgates/",
                        name_similarity_score=0.93,
                        company_similarity_score=0.83,
                        title_similarity_score=0.3,
                        location_similarity_score=0.20,
                    ),
                )
            ],
            test_credentials=TEST_CREDENTIALS,
            # The mock stands in for the network call made by _lookup_person.
            test_mock={
                "_lookup_person": lambda *args, **kwargs: PersonLookupResponse(
                    url="https://www.linkedin.com/in/williamhgates/",
                    name_similarity_score=0.93,
                    company_similarity_score=0.83,
                    title_similarity_score=0.3,
                    location_similarity_score=0.20,
                )
            },
        )

    @staticmethod
    async def _lookup_person(
        credentials: APIKeyCredentials,
        first_name: str,
        company_domain: str,
        last_name: str | None = None,
        location: Optional[str] = None,
        title: Optional[str] = None,
        include_similarity_checks: bool = False,
        enrich_profile: bool = False,
    ):
        """Create a client for ``credentials`` and run the person lookup.

        Kept as a separate static method so ``test_mock`` can replace it.
        """
        client = EnrichlayerClient(credentials=credentials)
        lookup_result = await client.lookup_person(
            first_name=first_name,
            last_name=last_name,
            company_domain=company_domain,
            location=location,
            title=title,
            include_similarity_checks=include_similarity_checks,
            enrich_profile=enrich_profile,
        )
        return lookup_result

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Run the block to look up LinkedIn profiles.

        Args:
            input_data: Input parameters for the block
            credentials: API key credentials for Enrichlayer
            **kwargs: Additional keyword arguments

        Yields:
            Tuples of (output_name, output_value): "lookup_result" on
            success, "error" with the exception text on failure.
        """
        try:
            lookup_result = await self._lookup_person(
                credentials=credentials,
                first_name=input_data.first_name,
                last_name=input_data.last_name,
                company_domain=input_data.company_domain,
                location=input_data.location,
                title=input_data.title,
                include_similarity_checks=input_data.include_similarity_checks,
                enrich_profile=input_data.enrich_profile,
            )
            yield "lookup_result", lookup_result
        except Exception as e:
            # Failures are surfaced on the "error" output instead of raising.
            logger.error(f"Error looking up LinkedIn profile: {str(e)}")
            yield "error", str(e)
class LinkedinRoleLookupBlock(Block):
    """Block to look up LinkedIn profiles by role in a company using Enrichlayer API."""

    class Input(BlockSchema):
        """Input schema for LinkedinRoleLookupBlock."""

        role: str = SchemaField(
            description="Role title (e.g., CEO, CTO)",
            placeholder="CEO",
        )
        company_name: str = SchemaField(
            description="Name of the company",
            placeholder="Microsoft",
        )
        enrich_profile: bool = SchemaField(
            description="Enrich the profile with additional data",
            default=False,
            advanced=True,
        )
        credentials: EnrichlayerCredentialsInput = CredentialsField(
            description="Enrichlayer API credentials"
        )

    class Output(BlockSchema):
        """Output schema for LinkedinRoleLookupBlock."""

        role_lookup_result: RoleLookupResponse = SchemaField(
            description="LinkedIn role lookup result"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        """Register the block's id, schemas and self-test fixtures."""

        def _sample_result() -> RoleLookupResponse:
            # Builds a fresh fixture on every call, so the expected output and
            # each mocked lookup get their own instance.
            return RoleLookupResponse(
                linkedin_profile_url="https://www.linkedin.com/in/williamhgates/",
            )

        super().__init__(
            id="3b9fc742-06d4-49c7-b5ce-7e302dd7c8a7",
            description="Look up LinkedIn profiles by role in a company using Enrichlayer",
            categories={BlockCategory.SOCIAL},
            input_schema=LinkedinRoleLookupBlock.Input,
            output_schema=LinkedinRoleLookupBlock.Output,
            test_input={
                "role": "Co-chair",
                "company_name": "Gates Foundation",
                "enrich_profile": True,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_output=[("role_lookup_result", _sample_result())],
            test_credentials=TEST_CREDENTIALS,
            test_mock={
                "_lookup_role": lambda *args, **kwargs: _sample_result(),
            },
        )

    @staticmethod
    async def _lookup_role(
        credentials: APIKeyCredentials,
        role: str,
        company_name: str,
        enrich_profile: bool = False,
    ):
        """Create a client for ``credentials`` and run the role lookup.

        Kept as a separate static method so ``test_mock`` can replace it.
        """
        return await EnrichlayerClient(credentials=credentials).lookup_role(
            role=role,
            company_name=company_name,
            enrich_profile=enrich_profile,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Look up the holder of a role at a company.

        Args:
            input_data: Input parameters for the block
            credentials: API key credentials for Enrichlayer
            **kwargs: Additional keyword arguments (unused here)

        Yields:
            ("role_lookup_result", data) on success, or ("error", message)
            if anything raised during the lookup.
        """
        try:
            lookup_kwargs = {
                "credentials": credentials,
                "role": input_data.role,
                "company_name": input_data.company_name,
                "enrich_profile": input_data.enrich_profile,
            }
            yield "role_lookup_result", await self._lookup_role(**lookup_kwargs)
        except Exception as err:
            # Failures are surfaced on the "error" output instead of raising.
            logger.error(f"Error looking up role in company: {str(err)}")
            yield "error", str(err)
class GetLinkedinProfilePictureBlock(Block):
    """Block to get LinkedIn profile pictures using Enrichlayer API."""

    class Input(BlockSchema):
        """Input schema for GetLinkedinProfilePictureBlock."""

        linkedin_profile_url: str = SchemaField(
            description="LinkedIn profile URL",
            placeholder="https://www.linkedin.com/in/username/",
        )
        credentials: EnrichlayerCredentialsInput = CredentialsField(
            description="Enrichlayer API credentials"
        )

    class Output(BlockSchema):
        """Output schema for GetLinkedinProfilePictureBlock."""

        profile_picture_url: MediaFileType = SchemaField(
            description="LinkedIn profile picture URL"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        """Register the block's id, schemas and self-test fixtures."""
        # The same picture URL serves as the expected output and as the
        # mocked return value.
        sample_picture_url = "https://media.licdn.com/dms/image/C4D03AQFj-xjuXrLFSQ/profile-displayphoto-shrink_800_800/0/1576881858598?e=1686787200&v=beta&t=zrQC76QwsfQQIWthfOnrKRBMZ5D-qIAvzLXLmWgYvTk"
        super().__init__(
            id="68d5a942-9b3f-4e9a-b7c1-d96ea4321f0d",
            description="Get LinkedIn profile pictures using Enrichlayer",
            categories={BlockCategory.SOCIAL},
            input_schema=GetLinkedinProfilePictureBlock.Input,
            output_schema=GetLinkedinProfilePictureBlock.Output,
            test_input={
                "linkedin_profile_url": "https://www.linkedin.com/in/williamhgates/",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_output=[("profile_picture_url", sample_picture_url)],
            test_credentials=TEST_CREDENTIALS,
            test_mock={
                "_get_profile_picture": lambda *args, **kwargs: sample_picture_url,
            },
        )

    @staticmethod
    async def _get_profile_picture(
        credentials: APIKeyCredentials, linkedin_profile_url: str
    ):
        """Create a client for ``credentials`` and return the picture URL.

        Kept as a separate static method so ``test_mock`` can replace it.
        """
        picture = await EnrichlayerClient(credentials=credentials).get_profile_picture(
            linkedin_profile_url=linkedin_profile_url,
        )
        return picture.profile_picture_url

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Fetch the profile picture URL for a LinkedIn profile.

        Args:
            input_data: Input parameters for the block
            credentials: API key credentials for Enrichlayer
            **kwargs: Additional keyword arguments (unused here)

        Yields:
            ("profile_picture_url", url) on success, or ("error", message)
            if anything raised during the request.
        """
        try:
            yield "profile_picture_url", await self._get_profile_picture(
                credentials=credentials,
                linkedin_profile_url=input_data.linkedin_profile_url,
            )
        except Exception as err:
            # Failures are surfaced on the "error" output instead of raising.
            logger.error(f"Error getting profile picture: {str(err)}")
            yield "error", str(err)

View File

@@ -1,247 +0,0 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
# Enum definitions based on available options
class WebsetStatus(str, Enum):
    # String-valued status of a webset; used as the type of Webset.status.
    IDLE = "idle"
    PENDING = "pending"
    RUNNING = "running"
    PAUSED = "paused"
class WebsetSearchStatus(str, Enum):
    # String-valued status of a webset search.
    CREATED = "created"
    # Add more if known, based on example it's "created"
class ImportStatus(str, Enum):
    # String-valued status of an import.
    PENDING = "pending"
    # Add more if known
class ImportFormat(str, Enum):
    # String-valued file format of an import.
    CSV = "csv"
    # Add more if known
class EnrichmentStatus(str, Enum):
    # String-valued status of an enrichment.
    PENDING = "pending"
    # Add more if known
class EnrichmentFormat(str, Enum):
    # String-valued result format of an enrichment.
    TEXT = "text"
    # Add more if known
class MonitorStatus(str, Enum):
    # String-valued status of a monitor.
    ENABLED = "enabled"
    # Add more if known
class MonitorBehaviorType(str, Enum):
    # String-valued kind of behavior a monitor performs.
    SEARCH = "search"
    # Add more if known
class MonitorRunStatus(str, Enum):
    # String-valued status of a single monitor run.
    CREATED = "created"
    # Add more if known
class CanceledReason(str, Enum):
    # String-valued reason a search was canceled.
    WEBSET_DELETED = "webset_deleted"
    # Add more if known
class FailedReason(str, Enum):
    # String-valued reason an import failed.
    INVALID_FORMAT = "invalid_format"
    # Add more if known
class Confidence(str, Enum):
    # String-valued confidence level (see Expected.confidence).
    HIGH = "high"
    # Add more if known
# Nested models
class Entity(BaseModel):
    # Entity descriptor carrying only a type name.
    type: str
class Criterion(BaseModel):
    # One search criterion: a required description plus an optional success rate.
    description: str
    successRate: Optional[int] = None
class ExcludeItem(BaseModel):
    # Reference to an item excluded from a search; `source` defaults to "import".
    source: str = "import"
    id: str
class Relationship(BaseModel):
    # Relationship used by ScopeItem: a required definition and an optional limit.
    definition: str
    limit: Optional[float] = None
class ScopeItem(BaseModel):
    # Reference to an item a search is scoped to; `source` defaults to "import".
    source: str = "import"
    id: str
    relationship: Optional[Relationship] = None
class Progress(BaseModel):
    # Progress counters of a webset search; all four integers are required.
    found: int
    analyzed: int
    completion: int
    timeLeft: int
class Bounds(BaseModel):
    # Integer lower and upper bounds (used by Expected.bounds).
    min: int
    max: int
class Expected(BaseModel):
    # Expected result count with a confidence label and bounds.
    total: int
    confidence: str = "high"  # Use str or Confidence enum
    bounds: Bounds
class Recall(BaseModel):
    # Recall estimate: the expected figures plus free-text reasoning.
    expected: Expected
    reasoning: str
class WebsetSearch(BaseModel):
    # A search attached to a webset. Field names are camelCase;
    # presumably they mirror the upstream API payload keys — confirm.
    id: str
    object: str = "webset_search"
    status: str = "created"  # Or use WebsetSearchStatus
    websetId: str
    query: str
    entity: Entity
    criteria: List[Criterion]
    count: int
    behavior: str = "override"
    exclude: List[ExcludeItem]
    scope: List[ScopeItem]
    progress: Progress
    recall: Recall
    metadata: Dict[str, Any] = Field(default_factory=dict)
    canceledAt: Optional[datetime] = None
    canceledReason: Optional[str] = None  # Or use CanceledReason
    createdAt: datetime
    updatedAt: datetime
class ImportEntity(BaseModel):
    # Entity descriptor for an import; carries only a type name.
    type: str
class Import(BaseModel):
    # An import attached to a webset, including optional failure details.
    id: str
    object: str = "import"
    status: str = "pending"  # Or use ImportStatus
    format: str = "csv"  # Or use ImportFormat
    entity: ImportEntity
    title: str
    count: int
    metadata: Dict[str, Any] = Field(default_factory=dict)
    failedReason: Optional[str] = None  # Or use FailedReason
    failedAt: Optional[datetime] = None
    failedMessage: Optional[str] = None
    createdAt: datetime
    updatedAt: datetime
class Option(BaseModel):
    # A single labelled option (used by WebsetEnrichment.options).
    label: str
class WebsetEnrichment(BaseModel):
    # An enrichment attached to a webset.
    id: str
    object: str = "webset_enrichment"
    status: str = "pending"  # Or use EnrichmentStatus
    websetId: str
    title: str
    description: str
    format: str = "text"  # Or use EnrichmentFormat
    options: List[Option]
    instructions: str
    metadata: Dict[str, Any] = Field(default_factory=dict)
    createdAt: datetime
    updatedAt: datetime
class Cadence(BaseModel):
    # Schedule of a monitor: a cron expression and a timezone (default "Etc/UTC").
    cron: str
    timezone: str = "Etc/UTC"
class BehaviorConfig(BaseModel):
    # Configuration of a monitor behavior; every field is optional.
    query: Optional[str] = None
    criteria: Optional[List[Criterion]] = None
    entity: Optional[Entity] = None
    count: Optional[int] = None
    behavior: Optional[str] = None
class Behavior(BaseModel):
    # What a monitor does on each run: a type (default "search") and its config.
    type: str = "search"  # Or use MonitorBehaviorType
    config: BehaviorConfig
class MonitorRun(BaseModel):
    # One execution of a monitor, with optional completion/failure/cancel timestamps.
    id: str
    object: str = "monitor_run"
    status: str = "created"  # Or use MonitorRunStatus
    monitorId: str
    type: str = "search"
    completedAt: Optional[datetime] = None
    failedAt: Optional[datetime] = None
    failedReason: Optional[str] = None
    canceledAt: Optional[datetime] = None
    createdAt: datetime
    updatedAt: datetime
class Monitor(BaseModel):
    # A monitor attached to a webset: cadence, behavior and last/next run info.
    id: str
    object: str = "monitor"
    status: str = "enabled"  # Or use MonitorStatus
    websetId: str
    cadence: Cadence
    behavior: Behavior
    lastRun: Optional[MonitorRun] = None
    nextRunAt: Optional[datetime] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
    createdAt: datetime
    updatedAt: datetime
class Webset(BaseModel):
    # Top-level webset record aggregating its searches, imports,
    # enrichments, monitors and streams.
    id: str
    object: str = "webset"
    status: WebsetStatus
    externalId: Optional[str] = None
    title: Optional[str] = None
    searches: List[WebsetSearch]
    imports: List[Import]
    enrichments: List[WebsetEnrichment]
    monitors: List[Monitor]
    streams: List[Any]
    createdAt: datetime
    updatedAt: datetime
    metadata: Dict[str, Any] = Field(default_factory=dict)
class ListWebsets(BaseModel):
    # One page of websets: the items, a has-more flag and an optional cursor.
    data: List[Webset]
    hasMore: bool
    nextCursor: Optional[str] = None

View File

@@ -114,7 +114,6 @@ class ExaWebsetWebhookBlock(Block):
def __init__(self):
super().__init__(
disabled=True,
id="d0204ed8-8b81-408d-8b8d-ed087a546228",
description="Receive webhook notifications for Exa webset events",
categories={BlockCategory.INPUT},

View File

@@ -1,33 +1,7 @@
from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Dict, List, Optional
from exa_py import Exa
from exa_py.websets.types import (
CreateCriterionParameters,
CreateEnrichmentParameters,
CreateWebsetParameters,
CreateWebsetParametersSearch,
ExcludeItem,
Format,
ImportItem,
ImportSource,
Option,
ScopeItem,
ScopeRelationship,
ScopeSourceType,
WebsetArticleEntity,
WebsetCompanyEntity,
WebsetCustomEntity,
WebsetPersonEntity,
WebsetResearchPaperEntity,
WebsetStatus,
)
from pydantic import Field
from typing import Any, Optional
from backend.sdk import (
APIKeyCredentials,
BaseModel,
Block,
BlockCategory,
BlockOutput,
@@ -38,69 +12,7 @@ from backend.sdk import (
)
from ._config import exa
class SearchEntityType(str, Enum):
    # String-valued kinds of entity a search can target.
    COMPANY = "company"
    PERSON = "person"
    ARTICLE = "article"
    RESEARCH_PAPER = "research_paper"
    CUSTOM = "custom"
    AUTO = "auto"
class SearchType(str, Enum):
    # String-valued source kinds: an import or a webset.
    IMPORT = "import"
    WEBSET = "webset"
class EnrichmentFormat(str, Enum):
    # String-valued result formats an enrichment can produce.
    TEXT = "text"
    DATE = "date"
    NUMBER = "number"
    OPTIONS = "options"
    EMAIL = "email"
    PHONE = "phone"
class Webset(BaseModel):
id: str
status: WebsetStatus | None = Field(..., title="WebsetStatus")
"""
The status of the webset
"""
external_id: Annotated[Optional[str], Field(alias="externalId")] = None
"""
The external identifier for the webset
NOTE: Returning dict to avoid ui crashing due to nested objects
"""
searches: List[dict[str, Any]] | None = None
"""
The searches that have been performed on the webset.
NOTE: Returning dict to avoid ui crashing due to nested objects
"""
enrichments: List[dict[str, Any]] | None = None
"""
The Enrichments to apply to the Webset Items.
NOTE: Returning dict to avoid ui crashing due to nested objects
"""
monitors: List[dict[str, Any]] | None = None
"""
The Monitors for the Webset.
NOTE: Returning dict to avoid ui crashing due to nested objects
"""
metadata: Optional[Dict[str, Any]] = {}
"""
Set of key-value pairs you want to associate with this object.
"""
created_at: Annotated[datetime, Field(alias="createdAt")] | None = None
"""
The date and time the webset was created
"""
updated_at: Annotated[datetime, Field(alias="updatedAt")] | None = None
"""
The date and time the webset was last updated
"""
from .helpers import WebsetEnrichmentConfig, WebsetSearchConfig
class ExaCreateWebsetBlock(Block):
@@ -108,121 +20,40 @@ class ExaCreateWebsetBlock(Block):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
# Search parameters (flattened)
search_query: str = SchemaField(
description="Your search query. Use this to describe what you are looking for. Any URL provided will be crawled and used as context for the search.",
placeholder="Marketing agencies based in the US, that focus on consumer products",
search: WebsetSearchConfig = SchemaField(
description="Initial search configuration for the Webset"
)
search_count: Optional[int] = SchemaField(
default=10,
description="Number of items the search will attempt to find. The actual number of items found may be less than this number depending on the search complexity.",
ge=1,
le=1000,
)
search_entity_type: SearchEntityType = SchemaField(
default=SearchEntityType.AUTO,
description="Entity type: 'company', 'person', 'article', 'research_paper', or 'custom'. If not provided, we automatically detect the entity from the query.",
advanced=True,
)
search_entity_description: Optional[str] = SchemaField(
enrichments: Optional[list[WebsetEnrichmentConfig]] = SchemaField(
default=None,
description="Description for custom entity type (required when search_entity_type is 'custom')",
description="Enrichments to apply to Webset items",
advanced=True,
)
# Search criteria (flattened)
search_criteria: list[str] = SchemaField(
default_factory=list,
description="List of criteria descriptions that every item will be evaluated against. If not provided, we automatically detect the criteria from the query.",
advanced=True,
)
# Search exclude sources (flattened)
search_exclude_sources: list[str] = SchemaField(
default_factory=list,
description="List of source IDs (imports or websets) to exclude from search results",
advanced=True,
)
search_exclude_types: list[SearchType] = SchemaField(
default_factory=list,
description="List of source types corresponding to exclude sources ('import' or 'webset')",
advanced=True,
)
# Search scope sources (flattened)
search_scope_sources: list[str] = SchemaField(
default_factory=list,
description="List of source IDs (imports or websets) to limit search scope to",
advanced=True,
)
search_scope_types: list[SearchType] = SchemaField(
default_factory=list,
description="List of source types corresponding to scope sources ('import' or 'webset')",
advanced=True,
)
search_scope_relationships: list[str] = SchemaField(
default_factory=list,
description="List of relationship definitions for hop searches (optional, one per scope source)",
advanced=True,
)
search_scope_relationship_limits: list[int] = SchemaField(
default_factory=list,
description="List of limits on the number of related entities to find (optional, one per scope relationship)",
advanced=True,
)
# Import parameters (flattened)
import_sources: list[str] = SchemaField(
default_factory=list,
description="List of source IDs to import from",
advanced=True,
)
import_types: list[SearchType] = SchemaField(
default_factory=list,
description="List of source types corresponding to import sources ('import' or 'webset')",
advanced=True,
)
# Enrichment parameters (flattened)
enrichment_descriptions: list[str] = SchemaField(
default_factory=list,
description="List of enrichment task descriptions to perform on each webset item",
advanced=True,
)
enrichment_formats: list[EnrichmentFormat] = SchemaField(
default_factory=list,
description="List of formats for enrichment responses ('text', 'date', 'number', 'options', 'email', 'phone'). If not specified, we automatically select the best format.",
advanced=True,
)
enrichment_options: list[list[str]] = SchemaField(
default_factory=list,
description="List of option lists for enrichments with 'options' format. Each inner list contains the option labels.",
advanced=True,
)
enrichment_metadata: list[dict] = SchemaField(
default_factory=list,
description="List of metadata dictionaries for enrichments",
advanced=True,
)
# Webset metadata
external_id: Optional[str] = SchemaField(
default=None,
description="External identifier for the webset. You can use this to reference the webset by your own internal identifiers.",
description="External identifier for the webset",
placeholder="my-webset-123",
advanced=True,
)
metadata: Optional[dict] = SchemaField(
default_factory=dict,
default=None,
description="Key-value pairs to associate with this webset",
advanced=True,
)
class Output(BlockSchema):
webset: Webset = SchemaField(
webset_id: str = SchemaField(
description="The unique identifier for the created webset"
)
status: str = SchemaField(description="The status of the webset")
external_id: Optional[str] = SchemaField(
description="The external identifier for the webset", default=None
)
created_at: str = SchemaField(
description="The date and time the webset was created"
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
def __init__(self):
super().__init__(
@@ -236,171 +67,44 @@ class ExaCreateWebsetBlock(Block):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/websets/v0/websets"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}
exa = Exa(credentials.api_key.get_secret_value())
# Build the payload
payload: dict[str, Any] = {
"search": input_data.search.model_dump(exclude_none=True),
}
# ------------------------------------------------------------
# Build entity (if explicitly provided)
# ------------------------------------------------------------
entity = None
if input_data.search_entity_type == SearchEntityType.COMPANY:
entity = WebsetCompanyEntity(type="company")
elif input_data.search_entity_type == SearchEntityType.PERSON:
entity = WebsetPersonEntity(type="person")
elif input_data.search_entity_type == SearchEntityType.ARTICLE:
entity = WebsetArticleEntity(type="article")
elif input_data.search_entity_type == SearchEntityType.RESEARCH_PAPER:
entity = WebsetResearchPaperEntity(type="research_paper")
elif (
input_data.search_entity_type == SearchEntityType.CUSTOM
and input_data.search_entity_description
):
entity = WebsetCustomEntity(
type="custom", description=input_data.search_entity_description
)
# Convert enrichments to API format
if input_data.enrichments:
enrichments_data = []
for enrichment in input_data.enrichments:
enrichments_data.append(enrichment.model_dump(exclude_none=True))
payload["enrichments"] = enrichments_data
# ------------------------------------------------------------
# Build criteria list
# ------------------------------------------------------------
criteria = None
if input_data.search_criteria:
criteria = [
CreateCriterionParameters(description=item)
for item in input_data.search_criteria
]
if input_data.external_id:
payload["externalId"] = input_data.external_id
# ------------------------------------------------------------
# Build exclude sources list
# ------------------------------------------------------------
exclude_items = None
if input_data.search_exclude_sources:
exclude_items = []
for idx, src_id in enumerate(input_data.search_exclude_sources):
src_type = None
if input_data.search_exclude_types and idx < len(
input_data.search_exclude_types
):
src_type = input_data.search_exclude_types[idx]
# Default to IMPORT if type missing
if src_type == SearchType.WEBSET:
source_enum = ImportSource.webset
else:
source_enum = ImportSource.import_
exclude_items.append(ExcludeItem(source=source_enum, id=src_id))
if input_data.metadata:
payload["metadata"] = input_data.metadata
# ------------------------------------------------------------
# Build scope list
# ------------------------------------------------------------
scope_items = None
if input_data.search_scope_sources:
scope_items = []
for idx, src_id in enumerate(input_data.search_scope_sources):
src_type = None
if input_data.search_scope_types and idx < len(
input_data.search_scope_types
):
src_type = input_data.search_scope_types[idx]
relationship = None
if input_data.search_scope_relationships and idx < len(
input_data.search_scope_relationships
):
rel_def = input_data.search_scope_relationships[idx]
lim = None
if input_data.search_scope_relationship_limits and idx < len(
input_data.search_scope_relationship_limits
):
lim = input_data.search_scope_relationship_limits[idx]
relationship = ScopeRelationship(definition=rel_def, limit=lim)
if src_type == SearchType.WEBSET:
src_enum = ScopeSourceType.webset
else:
src_enum = ScopeSourceType.import_
scope_items.append(
ScopeItem(source=src_enum, id=src_id, relationship=relationship)
)
try:
response = await Requests().post(url, headers=headers, json=payload)
data = response.json()
# ------------------------------------------------------------
# Assemble search parameters (only if a query is provided)
# ------------------------------------------------------------
search_params = None
if input_data.search_query:
search_params = CreateWebsetParametersSearch(
query=input_data.search_query,
count=input_data.search_count,
entity=entity,
criteria=criteria,
exclude=exclude_items,
scope=scope_items,
)
yield "webset_id", data.get("id", "")
yield "status", data.get("status", "")
yield "external_id", data.get("externalId")
yield "created_at", data.get("createdAt", "")
# ------------------------------------------------------------
# Build imports list
# ------------------------------------------------------------
imports_params = None
if input_data.import_sources:
imports_params = []
for idx, src_id in enumerate(input_data.import_sources):
src_type = None
if input_data.import_types and idx < len(input_data.import_types):
src_type = input_data.import_types[idx]
if src_type == SearchType.WEBSET:
source_enum = ImportSource.webset
else:
source_enum = ImportSource.import_
imports_params.append(ImportItem(source=source_enum, id=src_id))
# ------------------------------------------------------------
# Build enrichment list
# ------------------------------------------------------------
enrichments_params = None
if input_data.enrichment_descriptions:
enrichments_params = []
for idx, desc in enumerate(input_data.enrichment_descriptions):
fmt = None
if input_data.enrichment_formats and idx < len(
input_data.enrichment_formats
):
fmt_enum = input_data.enrichment_formats[idx]
if fmt_enum is not None:
fmt = Format(
fmt_enum.value if isinstance(fmt_enum, Enum) else fmt_enum
)
options_list = None
if input_data.enrichment_options and idx < len(
input_data.enrichment_options
):
raw_opts = input_data.enrichment_options[idx]
if raw_opts:
options_list = [Option(label=o) for o in raw_opts]
metadata_obj = None
if input_data.enrichment_metadata and idx < len(
input_data.enrichment_metadata
):
metadata_obj = input_data.enrichment_metadata[idx]
enrichments_params.append(
CreateEnrichmentParameters(
description=desc,
format=fmt,
options=options_list,
metadata=metadata_obj,
)
)
# ------------------------------------------------------------
# Create the webset
# ------------------------------------------------------------
webset = exa.websets.create(
params=CreateWebsetParameters(
search=search_params,
imports=imports_params,
enrichments=enrichments_params,
external_id=input_data.external_id,
metadata=input_data.metadata,
)
)
# Use alias field names returned from Exa SDK so that nested models validate correctly
yield "webset", Webset.model_validate(webset.model_dump(by_alias=True))
except Exception as e:
yield "error", str(e)
yield "webset_id", ""
yield "status", ""
yield "created_at", ""
class ExaUpdateWebsetBlock(Block):
@@ -479,11 +183,6 @@ class ExaListWebsetsBlock(Block):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
trigger: Any | None = SchemaField(
default=None,
description="Trigger for the webset, value is ignored!",
advanced=False,
)
cursor: Optional[str] = SchemaField(
default=None,
description="Cursor for pagination through results",
@@ -498,9 +197,7 @@ class ExaListWebsetsBlock(Block):
)
class Output(BlockSchema):
websets: list[Webset] = SchemaField(
description="List of websets", default_factory=list
)
websets: list = SchemaField(description="List of websets", default_factory=list)
has_more: bool = SchemaField(
description="Whether there are more results to paginate through",
default=False,
@@ -558,6 +255,9 @@ class ExaGetWebsetBlock(Block):
description="The ID or external ID of the Webset to retrieve",
placeholder="webset-id-or-external-id",
)
expand_items: bool = SchemaField(
default=False, description="Include items in the response", advanced=True
)
class Output(BlockSchema):
webset_id: str = SchemaField(description="The unique identifier for the webset")
@@ -609,8 +309,12 @@ class ExaGetWebsetBlock(Block):
"x-api-key": credentials.api_key.get_secret_value(),
}
params = {}
if input_data.expand_items:
params["expand[]"] = "items"
try:
response = await Requests().get(url, headers=headers)
response = await Requests().get(url, headers=headers, params=params)
data = response.json()
yield "webset_id", data.get("id", "")

View File

@@ -29,8 +29,8 @@ class FirecrawlExtractBlock(Block):
prompt: str | None = SchemaField(
description="The prompt to use for the crawl", default=None, advanced=False
)
output_schema: dict | None = SchemaField(
description="A Json Schema describing the output structure if more rigid structure is desired.",
output_schema: str | None = SchemaField(
description="A more rigid structure if you already know the JSON layout.",
default=None,
)
enable_web_search: bool = SchemaField(
@@ -56,6 +56,7 @@ class FirecrawlExtractBlock(Block):
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
# Sync call
extract_result = app.extract(
urls=input_data.urls,
prompt=input_data.prompt,

View File

@@ -1,388 +0,0 @@
import logging
import re
from enum import Enum
from typing import Optional
from typing_extensions import TypedDict
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
logger = logging.getLogger(__name__)
class CheckRunStatus(Enum):
QUEUED = "queued"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
class CheckRunConclusion(Enum):
SUCCESS = "success"
FAILURE = "failure"
NEUTRAL = "neutral"
CANCELLED = "cancelled"
SKIPPED = "skipped"
TIMED_OUT = "timed_out"
ACTION_REQUIRED = "action_required"
class GithubGetCIResultsBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo: str = SchemaField(
description="GitHub repository",
placeholder="owner/repo",
)
target: str | int = SchemaField(
description="Commit SHA or PR number to get CI results for",
placeholder="abc123def or 123",
)
search_pattern: Optional[str] = SchemaField(
description="Optional regex pattern to search for in CI logs (e.g., error messages, file names)",
placeholder=".*error.*|.*warning.*",
default=None,
advanced=True,
)
check_name_filter: Optional[str] = SchemaField(
description="Optional filter for specific check names (supports wildcards)",
placeholder="*lint* or build-*",
default=None,
advanced=True,
)
class Output(BlockSchema):
class CheckRunItem(TypedDict, total=False):
id: int
name: str
status: str
conclusion: Optional[str]
started_at: Optional[str]
completed_at: Optional[str]
html_url: str
details_url: Optional[str]
output_title: Optional[str]
output_summary: Optional[str]
output_text: Optional[str]
annotations: list[dict]
class MatchedLine(TypedDict):
check_name: str
line_number: int
line: str
context: list[str]
check_run: CheckRunItem = SchemaField(
title="Check Run",
description="Individual CI check run with details",
)
check_runs: list[CheckRunItem] = SchemaField(
description="List of all CI check runs"
)
matched_line: MatchedLine = SchemaField(
title="Matched Line",
description="Line matching the search pattern with context",
)
matched_lines: list[MatchedLine] = SchemaField(
description="All lines matching the search pattern across all checks"
)
overall_status: str = SchemaField(
description="Overall CI status (pending, success, failure)"
)
overall_conclusion: str = SchemaField(
description="Overall CI conclusion if completed"
)
total_checks: int = SchemaField(description="Total number of CI checks")
passed_checks: int = SchemaField(description="Number of passed checks")
failed_checks: int = SchemaField(description="Number of failed checks")
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="8ad9e103-78f2-4fdb-ba12-3571f2c95e98",
description="This block gets CI results for a commit or PR, with optional search for specific errors/warnings in logs.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubGetCIResultsBlock.Input,
output_schema=GithubGetCIResultsBlock.Output,
test_input={
"repo": "owner/repo",
"target": "abc123def456",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("overall_status", "completed"),
("overall_conclusion", "success"),
("total_checks", 1),
("passed_checks", 1),
("failed_checks", 0),
(
"check_runs",
[
{
"id": 123456,
"name": "build",
"status": "completed",
"conclusion": "success",
"started_at": "2024-01-01T00:00:00Z",
"completed_at": "2024-01-01T00:05:00Z",
"html_url": "https://github.com/owner/repo/runs/123456",
"details_url": None,
"output_title": "Build passed",
"output_summary": "All tests passed",
"output_text": "Build log output...",
"annotations": [],
}
],
),
],
test_mock={
"get_ci_results": lambda *args, **kwargs: {
"check_runs": [
{
"id": 123456,
"name": "build",
"status": "completed",
"conclusion": "success",
"started_at": "2024-01-01T00:00:00Z",
"completed_at": "2024-01-01T00:05:00Z",
"html_url": "https://github.com/owner/repo/runs/123456",
"details_url": None,
"output_title": "Build passed",
"output_summary": "All tests passed",
"output_text": "Build log output...",
"annotations": [],
}
],
"total_count": 1,
}
},
)
@staticmethod
async def get_commit_sha(api, repo: str, target: str | int) -> str:
    """Resolve ``target`` to a commit SHA.

    Args:
        api: Client exposing an awaitable ``get(url)`` whose response has
            a ``json()`` method.
        repo: Repository in ``owner/repo`` form; interpolated into the URL.
        target: A hexadecimal commit SHA (6-40 characters), or a pull
            request number given as an ``int`` or a digits-only string.

    Returns:
        ``target`` itself when it is SHA-shaped, otherwise the head commit
        SHA of the pull request.

    Raises:
        ValueError: If ``target`` is neither SHA-shaped nor a PR number.
    """
    if isinstance(target, str):
        target = target.strip()
        # SHA-shaped strings take precedence, so a digits-only string of six
        # or more characters is still treated as an (abbreviated) SHA.
        if re.fullmatch(r"[0-9a-f]{6,40}", target, re.IGNORECASE):
            return target
        # Shorter digits-only strings (e.g. "123") can only be PR numbers.
        if target.isascii() and target.isdigit():
            target = int(target)

    if isinstance(target, int):
        pr_url = f"https://api.github.com/repos/{repo}/pulls/{target}"
        response = await api.get(pr_url)
        return response.json()["head"]["sha"]

    raise ValueError("Target must be a commit SHA or PR number")
@staticmethod
async def search_in_logs(
    check_runs: list,
    pattern: str,
) -> list[Output.MatchedLine]:
    """Find lines in check-run output text that match ``pattern``.

    Args:
        check_runs: Dicts carrying ``"name"`` and an optional
            ``"output_text"`` string; runs without text are skipped.
        pattern: Regular expression, applied case-insensitively to each
            line. An empty pattern yields no matches.

    Returns:
        One dict per matching line with ``check_name``, the 1-based
        ``line_number``, the ``line`` itself and up to two lines of
        ``context`` on either side.

    Raises:
        re.error: If ``pattern`` is not a valid regular expression.
    """
    if not pattern:
        return []

    regex = re.compile(pattern, re.IGNORECASE | re.MULTILINE)
    matched_lines = []

    for check in check_runs:
        output_text = check.get("output_text") or ""
        if not output_text:
            continue

        # Logs with CRLF endings would otherwise keep a trailing "\r" on
        # every line, which defeats "$"-anchored patterns and leaks into
        # the reported line and context.
        lines = [line.removesuffix("\r") for line in output_text.split("\n")]
        for index, line in enumerate(lines):
            if not regex.search(line):
                continue
            matched_lines.append(
                {
                    "check_name": check["name"],
                    "line_number": index + 1,
                    "line": line,
                    # Slicing clamps at the end of the list on its own.
                    "context": lines[max(0, index - 2) : index + 3],
                }
            )

    return matched_lines
@staticmethod
async def get_ci_results(
    credentials: GithubCredentials,
    repo: str,
    target: str | int,
    search_pattern: Optional[str] = None,
    check_name_filter: Optional[str] = None,
) -> dict:
    """Fetch and normalize the check runs of a commit or pull request.

    Args:
        credentials: GitHub credentials handed to ``get_api``.
        repo: Repository in ``owner/repo`` form.
        target: Commit SHA or PR number, resolved via ``get_commit_sha``.
        search_pattern: Accepted for signature compatibility; not used
            here (log searching is done by ``search_in_logs``).
        check_name_filter: Optional ``fnmatch``-style wildcard compared
            case-insensitively against each check run's name.

    Returns:
        ``{"check_runs": [...], "total_count": n}`` where every entry is a
        flat dict (id, name, status, conclusion, timestamps, URLs, output
        title/summary/text and annotations).
    """
    api = get_api(credentials, convert_urls=False)

    commit_sha = await GithubGetCIResultsBlock.get_commit_sha(api, repo, target)
    check_runs_url = (
        f"https://api.github.com/repos/{repo}/commits/{commit_sha}/check-runs"
    )

    # Walk the pages until one comes back short.
    all_check_runs = []
    page = 1
    per_page = 100
    while True:
        response = await api.get(
            check_runs_url, params={"per_page": per_page, "page": page}
        )
        check_runs = response.json().get("check_runs", [])
        all_check_runs.extend(check_runs)
        if len(check_runs) < per_page:
            break
        page += 1

    if check_name_filter:
        import fnmatch

        wanted = check_name_filter.lower()
        all_check_runs = [
            run
            for run in all_check_runs
            if fnmatch.fnmatch(run["name"].lower(), wanted)
        ]

    detailed_runs = []
    for run in all_check_runs:
        # "output" may be absent or null; treat both as an empty mapping.
        output = run.get("output") or {}
        detailed_run = {
            "id": run["id"],
            "name": run["name"],
            "status": run["status"],
            "conclusion": run.get("conclusion"),
            "started_at": run.get("started_at"),
            "completed_at": run.get("completed_at"),
            "html_url": run["html_url"],
            "details_url": run.get("details_url"),
            "output_title": output.get("title"),
            "output_summary": output.get("summary"),
            # Empty text is reported as None, same as missing text.
            "output_text": output.get("text") or None,
            "annotations": [],
        }

        if (output.get("annotations_count") or 0) > 0:
            annotations_url = f"https://api.github.com/repos/{repo}/check-runs/{run['id']}/annotations"
            try:
                ann_response = await api.get(annotations_url)
                detailed_run["annotations"] = ann_response.json()
            except Exception:
                # Annotations are best-effort: keep the run, but leave a
                # trace instead of swallowing the failure silently.
                logger.warning(
                    "Could not fetch annotations for check run %s in %s",
                    run["id"],
                    repo,
                    exc_info=True,
                )

        detailed_runs.append(detailed_run)

    return {
        "check_runs": detailed_runs,
        "total_count": len(detailed_runs),
    }
async def run(
    self,
    input_data: Input,
    *,
    credentials: GithubCredentials,
    **kwargs,
) -> BlockOutput:
    """Yield the CI summary, the check runs and any log matches.

    Outputs are produced in this order: ``overall_status``,
    ``overall_conclusion``, ``total_checks``, ``passed_checks``,
    ``failed_checks``, ``check_runs`` and, only when a search pattern is
    set and something matched, ``matched_lines``.
    """
    # A target that parses as an integer is handled as a PR number.
    # NOTE: this also applies to an abbreviated SHA made only of digits.
    try:
        target = int(input_data.target)
    except ValueError:
        target = input_data.target

    result = await self.get_ci_results(
        credentials,
        input_data.repo,
        target,
        input_data.search_pattern,
        input_data.check_name_filter,
    )
    check_runs = result["check_runs"]

    # One definition of "failed" for both the overall conclusion and the
    # counter, so a "failure" conclusion never comes with zero failed checks.
    failing = ("failure", "timed_out", "action_required")
    passed = sum(1 for check in check_runs if check.get("conclusion") == "success")
    failed = sum(1 for check in check_runs if check.get("conclusion") in failing)

    if not check_runs:
        yield "overall_status", "no_checks"
        yield "overall_conclusion", "no_checks"
    elif all(check.get("status") == "completed" for check in check_runs):
        yield "overall_status", "completed"
        yield "overall_conclusion", "failure" if failed else "success"
    else:
        yield "overall_status", "pending"
        yield "overall_conclusion", "pending"

    yield "total_checks", len(check_runs)
    yield "passed_checks", passed
    yield "failed_checks", failed

    yield "check_runs", check_runs

    if input_data.search_pattern:
        matched_lines = await self.search_in_logs(
            check_runs, input_data.search_pattern
        )
        if matched_lines:
            yield "matched_lines", matched_lines

View File

@@ -1,840 +0,0 @@
import logging
from enum import Enum
from typing import Any, List, Optional
from typing_extensions import TypedDict
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
logger = logging.getLogger(__name__)
class ReviewEvent(Enum):
COMMENT = "COMMENT"
APPROVE = "APPROVE"
REQUEST_CHANGES = "REQUEST_CHANGES"
class GithubCreatePRReviewBlock(Block):
class Input(BlockSchema):
class ReviewComment(TypedDict, total=False):
path: str
position: Optional[int]
body: str
line: Optional[int] # Will be used as position if position not provided
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo: str = SchemaField(
description="GitHub repository",
placeholder="owner/repo",
)
pr_number: int = SchemaField(
description="Pull request number",
placeholder="123",
)
body: str = SchemaField(
description="Body of the review comment",
placeholder="Enter your review comment",
)
event: ReviewEvent = SchemaField(
description="The review action to perform",
default=ReviewEvent.COMMENT,
)
create_as_draft: bool = SchemaField(
description="Create the review as a draft (pending) or post it immediately",
default=False,
advanced=False,
)
comments: Optional[List[ReviewComment]] = SchemaField(
description="Optional inline comments to add to specific files/lines. Note: Only path, body, and position are supported. Position is line number in diff from first @@ hunk.",
default=None,
advanced=True,
)
class Output(BlockSchema):
review_id: int = SchemaField(description="ID of the created review")
state: str = SchemaField(
description="State of the review (e.g., PENDING, COMMENTED, APPROVED, CHANGES_REQUESTED)"
)
html_url: str = SchemaField(description="URL of the created review")
error: str = SchemaField(
description="Error message if the review creation failed"
)
def __init__(self):
super().__init__(
id="84754b30-97d2-4c37-a3b8-eb39f268275b",
description="This block creates a review on a GitHub pull request with optional inline comments. You can create it as a draft or post immediately. Note: For inline comments, 'position' should be the line number in the diff (starting from the first @@ hunk header).",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCreatePRReviewBlock.Input,
output_schema=GithubCreatePRReviewBlock.Output,
test_input={
"repo": "owner/repo",
"pr_number": 1,
"body": "This looks good to me!",
"event": "APPROVE",
"create_as_draft": False,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("review_id", 123456),
("state", "APPROVED"),
(
"html_url",
"https://github.com/owner/repo/pull/1#pullrequestreview-123456",
),
],
test_mock={
"create_review": lambda *args, **kwargs: (
123456,
"APPROVED",
"https://github.com/owner/repo/pull/1#pullrequestreview-123456",
)
},
)
@staticmethod
async def create_review(
    credentials: GithubCredentials,
    repo: str,
    pr_number: int,
    body: str,
    event: ReviewEvent,
    create_as_draft: bool,
    comments: Optional[List[Input.ReviewComment]] = None,
) -> tuple[int, str, str]:
    """Create a review on a pull request and return ``(id, state, html_url)``.

    When inline ``comments`` are given, the PR is fetched first so the
    review can be pinned to its head commit. Omitting ``event`` from the
    payload (``create_as_draft=True``) leaves the review pending; otherwise
    the event value is sent and the review is submitted right away.
    """
    api = get_api(credentials, convert_urls=False)
    pr_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}"

    payload: dict[str, Any] = {"body": body}

    if comments:
        # Inline comments are attached to the PR's current head commit.
        pr_response = await api.get(pr_url)
        head_sha = pr_response.json()["head"]["sha"]
        if head_sha:
            payload["commit_id"] = head_sha

        inline_comments = []
        for item in comments:
            entry: dict = {
                "path": item.get("path", ""),
                "body": item.get("body", ""),
            }
            # Only "position" is sent for review comments; "line" is used as
            # a stand-in when no position was given (it may not line up with
            # the diff position).
            anchor = item.get("position")
            if anchor is None:
                anchor = item.get("line")
            if anchor is not None:
                entry["position"] = anchor
            inline_comments.append(entry)
        payload["comments"] = inline_comments

    if not create_as_draft:
        payload["event"] = event.value

    response = await api.post(f"{pr_url}/reviews", json=payload)
    created = response.json()
    return created["id"], created["state"], created["html_url"]
async def run(
    self,
    input_data: Input,
    *,
    credentials: GithubCredentials,
    **kwargs,
) -> BlockOutput:
    """Create the review and yield its id, state and URL, or ``error``."""
    try:
        review_id, state, html_url = await self.create_review(
            credentials,
            input_data.repo,
            input_data.pr_number,
            input_data.body,
            input_data.event,
            input_data.create_as_draft,
            input_data.comments,
        )
        outputs = (
            ("review_id", review_id),
            ("state", state),
            ("html_url", html_url),
        )
        for output in outputs:
            yield output
    except Exception as e:
        # Any failure is reported through the "error" output.
        yield "error", str(e)
class GithubListPRReviewsBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo: str = SchemaField(
description="GitHub repository",
placeholder="owner/repo",
)
pr_number: int = SchemaField(
description="Pull request number",
placeholder="123",
)
class Output(BlockSchema):
class ReviewItem(TypedDict):
id: int
user: str
state: str
body: str
html_url: str
review: ReviewItem = SchemaField(
title="Review",
description="Individual review with details",
)
reviews: list[ReviewItem] = SchemaField(
description="List of all reviews on the pull request"
)
error: str = SchemaField(description="Error message if listing reviews failed")
def __init__(self):
super().__init__(
id="f79bc6eb-33c0-4099-9c0f-d664ae1ba4d0",
description="This block lists all reviews for a specified GitHub pull request.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListPRReviewsBlock.Input,
output_schema=GithubListPRReviewsBlock.Output,
test_input={
"repo": "owner/repo",
"pr_number": 1,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"reviews",
[
{
"id": 123456,
"user": "reviewer1",
"state": "APPROVED",
"body": "Looks good!",
"html_url": "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
}
],
),
(
"review",
{
"id": 123456,
"user": "reviewer1",
"state": "APPROVED",
"body": "Looks good!",
"html_url": "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
},
),
],
test_mock={
"list_reviews": lambda *args, **kwargs: [
{
"id": 123456,
"user": "reviewer1",
"state": "APPROVED",
"body": "Looks good!",
"html_url": "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
}
]
},
)
@staticmethod
async def list_reviews(
    credentials: GithubCredentials, repo: str, pr_number: int
) -> list[Output.ReviewItem]:
    """Return every review of a pull request as flat dicts.

    Args:
        credentials: GitHub credentials handed to ``get_api``.
        repo: Repository in ``owner/repo`` form.
        pr_number: Number of the pull request.

    Returns:
        One dict per review with ``id``, ``user`` (login), ``state``,
        ``body`` and ``html_url``.
    """
    api = get_api(credentials, convert_urls=False)
    reviews_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews"

    reviews: list[GithubListPRReviewsBlock.Output.ReviewItem] = []
    per_page = 100
    page = 1
    # A single request only returns one page; keep going until a page comes
    # back short so the result really contains all reviews.
    while True:
        response = await api.get(
            reviews_url, params={"per_page": per_page, "page": page}
        )
        batch = response.json()
        reviews.extend(
            {
                "id": review["id"],
                # "user" and "body" can be null in the payload; fall back to
                # empty strings so the item keeps its declared str fields.
                "user": (review.get("user") or {}).get("login", ""),
                "state": review["state"],
                "body": review.get("body") or "",
                "html_url": review["html_url"],
            }
            for review in batch
        )
        if len(batch) < per_page:
            break
        page += 1

    return reviews
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
reviews = await self.list_reviews(
credentials,
input_data.repo,
input_data.pr_number,
)
yield "reviews", reviews
for review in reviews:
yield "review", review
class GithubSubmitPendingReviewBlock(Block):
    """Submit a pending (draft) review on a GitHub pull request.

    On success the resulting review ``state`` and ``html_url`` are emitted;
    any exception is reported on ``error``.
    """

    class Input(BlockSchema):
        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
        repo: str = SchemaField(
            description="GitHub repository",
            placeholder="owner/repo",
        )
        pr_number: int = SchemaField(
            description="Pull request number",
            placeholder="123",
        )
        review_id: int = SchemaField(
            description="ID of the pending review to submit",
            placeholder="123456",
        )
        event: ReviewEvent = SchemaField(
            description="The review action to perform when submitting",
            default=ReviewEvent.COMMENT,
        )

    class Output(BlockSchema):
        state: str = SchemaField(description="State of the submitted review")
        html_url: str = SchemaField(description="URL of the submitted review")
        error: str = SchemaField(
            description="Error message if the review submission failed"
        )

    def __init__(self):
        super().__init__(
            id="2e468217-7ca0-4201-9553-36e93eb9357a",
            description="This block submits a pending (draft) review on a GitHub pull request.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=GithubSubmitPendingReviewBlock.Input,
            output_schema=GithubSubmitPendingReviewBlock.Output,
            test_input={
                "repo": "owner/repo",
                "pr_number": 1,
                "review_id": 123456,
                "event": "APPROVE",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("state", "APPROVED"),
                (
                    "html_url",
                    "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
                ),
            ],
            test_mock={
                "submit_review": lambda *args, **kwargs: (
                    "APPROVED",
                    "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
                )
            },
        )

    @staticmethod
    async def submit_review(
        credentials: GithubCredentials,
        repo: str,
        pr_number: int,
        review_id: int,
        event: ReviewEvent,
    ) -> tuple[str, str]:
        """POST the review event and return ``(state, html_url)`` from the reply."""
        api = get_api(credentials, convert_urls=False)

        # GitHub API endpoint for submitting a review
        submit_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews/{review_id}/events"
        response = await api.post(submit_url, json={"event": event.value})

        submitted = response.json()
        return submitted["state"], submitted["html_url"]

    async def run(
        self,
        input_data: Input,
        *,
        credentials: GithubCredentials,
        **kwargs,
    ) -> BlockOutput:
        """Submit the review and yield its state and URL, or ``error``."""
        try:
            state, html_url = await self.submit_review(
                credentials,
                input_data.repo,
                input_data.pr_number,
                input_data.review_id,
                input_data.event,
            )
            # Emit the outputs in their declared order.
            for output in (("state", state), ("html_url", html_url)):
                yield output
        except Exception as e:
            yield "error", str(e)
class GithubResolveReviewDiscussionBlock(Block):
    """Resolve or unresolve the review thread that contains a given comment.

    The thread is located through GitHub's GraphQL API by matching the
    comment's numeric ID, then the resolve/unresolve mutation is applied.
    ``success`` is always emitted; on failure it is ``False`` and the message
    is emitted on ``error``.
    """

    class Input(BlockSchema):
        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
        repo: str = SchemaField(
            description="GitHub repository",
            placeholder="owner/repo",
        )
        pr_number: int = SchemaField(
            description="Pull request number",
            placeholder="123",
        )
        comment_id: int = SchemaField(
            description="ID of the review comment to resolve/unresolve",
            placeholder="123456",
        )
        resolve: bool = SchemaField(
            description="Whether to resolve (true) or unresolve (false) the discussion",
            default=True,
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Whether the operation was successful")
        error: str = SchemaField(description="Error message if the operation failed")

    def __init__(self):
        super().__init__(
            id="b4b8a38c-95ae-4c91-9ef8-c2cffaf2b5d1",
            description="This block resolves or unresolves a review discussion thread on a GitHub pull request.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=GithubResolveReviewDiscussionBlock.Input,
            output_schema=GithubResolveReviewDiscussionBlock.Output,
            test_input={
                "repo": "owner/repo",
                "pr_number": 1,
                "comment_id": 123456,
                "resolve": True,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("success", True),
            ],
            test_mock={"resolve_discussion": lambda *args, **kwargs: True},
        )

    @staticmethod
    async def resolve_discussion(
        credentials: GithubCredentials,
        repo: str,
        pr_number: int,
        comment_id: int,
        resolve: bool,
    ) -> bool:
        """Resolve or unresolve the review thread containing ``comment_id``.

        Args:
            credentials: GitHub credentials used to build the API client.
            repo: Repository in ``owner/repo`` form.
            pr_number: Number of the pull request.
            comment_id: Numeric (database) ID of a comment in the thread.
            resolve: ``True`` to resolve the thread, ``False`` to unresolve it.

        Returns:
            ``True`` when the mutation response carries no errors.

        Raises:
            ValueError: If ``repo`` is not ``owner/repo``, the pull request
                cannot be found, or no fetched thread contains the comment.
            Exception: If GitHub reports GraphQL errors.
        """
        api = get_api(credentials, convert_urls=False)

        # Extract owner and repo name; fail with a clear message instead of an
        # IndexError when the value is not in "owner/repo" form.
        parts = repo.split("/")
        if len(parts) < 2 or not parts[0] or not parts[1]:
            raise ValueError(
                f"Invalid repository '{repo}': expected the form 'owner/repo'"
            )
        owner, repo_name = parts[0], parts[1]

        # GitHub GraphQL API is needed for resolving/unresolving discussions
        graphql_url = "https://api.github.com/graphql"

        # Query the review threads with their comments so the thread that owns
        # `comment_id` can be found. Only the first 100 threads, and the first
        # 100 comments of each, are requested.
        query = """
        query($owner: String!, $repo: String!, $number: Int!) {
            repository(owner: $owner, name: $repo) {
                pullRequest(number: $number) {
                    reviewThreads(first: 100) {
                        nodes {
                            comments(first: 100) {
                                nodes {
                                    databaseId
                                    id
                                }
                            }
                            id
                            isResolved
                        }
                    }
                }
            }
        }
        """
        variables = {"owner": owner, "repo": repo_name, "number": pr_number}
        response = await api.post(
            graphql_url, json={"query": query, "variables": variables}
        )
        data = response.json()

        # A missing repository/pull request comes back as null (usually with an
        # `errors` list); surface that instead of failing on a None subscript.
        pull_request = ((data.get("data") or {}).get("repository") or {}).get(
            "pullRequest"
        )
        if not pull_request:
            if data.get("errors"):
                raise Exception(f"GraphQL error: {data['errors']}")
            raise ValueError(f"Pull request #{pr_number} not found in {repo}")

        # Find the thread containing our comment
        thread_id = next(
            (
                thread["id"]
                for thread in pull_request["reviewThreads"]["nodes"]
                if any(
                    comment["databaseId"] == comment_id
                    for comment in thread["comments"]["nodes"]
                )
            ),
            None,
        )
        if not thread_id:
            raise ValueError(f"Comment {comment_id} not found in pull request")

        # Now resolve or unresolve the thread
        # GitHub's GraphQL API has separate mutations for resolve and unresolve
        if resolve:
            mutation = """
            mutation($threadId: ID!) {
                resolveReviewThread(input: {threadId: $threadId}) {
                    thread {
                        isResolved
                    }
                }
            }
            """
        else:
            mutation = """
            mutation($threadId: ID!) {
                unresolveReviewThread(input: {threadId: $threadId}) {
                    thread {
                        isResolved
                    }
                }
            }
            """
        mutation_variables = {"threadId": thread_id}
        response = await api.post(
            graphql_url, json={"query": mutation, "variables": mutation_variables}
        )
        result = response.json()
        if "errors" in result:
            raise Exception(f"GraphQL error: {result['errors']}")
        return True

    async def run(
        self,
        input_data: Input,
        *,
        credentials: GithubCredentials,
        **kwargs,
    ) -> BlockOutput:
        """Yield ``success``; on failure yield ``success=False`` and ``error``."""
        try:
            success = await self.resolve_discussion(
                credentials,
                input_data.repo,
                input_data.pr_number,
                input_data.comment_id,
                input_data.resolve,
            )
            yield "success", success
        except Exception as e:
            yield "success", False
            yield "error", str(e)
class GithubGetPRReviewCommentsBlock(Block):
    """Get the review comments of a pull request or of one specific review.

    Emits the whole list on ``comments`` followed by one ``comment`` output
    per entry; any exception is reported on ``error``.
    """

    class Input(BlockSchema):
        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
        repo: str = SchemaField(
            description="GitHub repository",
            placeholder="owner/repo",
        )
        pr_number: int = SchemaField(
            description="Pull request number",
            placeholder="123",
        )
        review_id: Optional[int] = SchemaField(
            description="ID of a specific review to get comments from (optional)",
            placeholder="123456",
            default=None,
            advanced=True,
        )

    class Output(BlockSchema):
        # Flattened view of a single review comment, as built by `get_comments`.
        class CommentItem(TypedDict):
            id: int
            user: str
            body: str
            path: str
            line: int
            side: str
            created_at: str
            updated_at: str
            in_reply_to_id: Optional[int]
            html_url: str

        comment: CommentItem = SchemaField(
            title="Comment",
            description="Individual review comment with details",
        )
        comments: list[CommentItem] = SchemaField(
            description="List of all review comments on the pull request"
        )
        error: str = SchemaField(description="Error message if getting comments failed")

    def __init__(self):
        super().__init__(
            id="1d34db7f-10c1-45c1-9d43-749f743c8bd4",
            description="This block gets all review comments from a GitHub pull request or from a specific review.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=GithubGetPRReviewCommentsBlock.Input,
            output_schema=GithubGetPRReviewCommentsBlock.Output,
            test_input={
                "repo": "owner/repo",
                "pr_number": 1,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                (
                    "comments",
                    [
                        {
                            "id": 123456,
                            "user": "reviewer1",
                            "body": "This needs improvement",
                            "path": "src/main.py",
                            "line": 42,
                            "side": "RIGHT",
                            "created_at": "2024-01-01T00:00:00Z",
                            "updated_at": "2024-01-01T00:00:00Z",
                            "in_reply_to_id": None,
                            "html_url": "https://github.com/owner/repo/pull/1#discussion_r123456",
                        }
                    ],
                ),
                (
                    "comment",
                    {
                        "id": 123456,
                        "user": "reviewer1",
                        "body": "This needs improvement",
                        "path": "src/main.py",
                        "line": 42,
                        "side": "RIGHT",
                        "created_at": "2024-01-01T00:00:00Z",
                        "updated_at": "2024-01-01T00:00:00Z",
                        "in_reply_to_id": None,
                        "html_url": "https://github.com/owner/repo/pull/1#discussion_r123456",
                    },
                ),
            ],
            test_mock={
                "get_comments": lambda *args, **kwargs: [
                    {
                        "id": 123456,
                        "user": "reviewer1",
                        "body": "This needs improvement",
                        "path": "src/main.py",
                        "line": 42,
                        "side": "RIGHT",
                        "created_at": "2024-01-01T00:00:00Z",
                        "updated_at": "2024-01-01T00:00:00Z",
                        "in_reply_to_id": None,
                        "html_url": "https://github.com/owner/repo/pull/1#discussion_r123456",
                    }
                ]
            },
        )

    @staticmethod
    async def get_comments(
        credentials: GithubCredentials,
        repo: str,
        pr_number: int,
        review_id: Optional[int] = None,
    ) -> list[Output.CommentItem]:
        """Fetch review comments and flatten them to `CommentItem`s.

        Args:
            credentials: GitHub credentials used to build the API client.
            repo: Repository in ``owner/repo`` form.
            pr_number: Number of the pull request.
            review_id: When truthy, only comments of that review are fetched;
                otherwise all review comments on the pull request.

        Returns:
            One item per comment in the API response, in response order.
            Missing or null ``path``/``line``/``side`` values are replaced by
            ``""``/``0``/``""`` so every item matches the declared types.
        """
        api = get_api(credentials, convert_urls=False)

        # Determine the endpoint based on whether we want comments from a specific review
        if review_id:
            # Get comments from a specific review
            comments_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews/{review_id}/comments"
        else:
            # Get all review comments on the PR
            comments_url = (
                f"https://api.github.com/repos/{repo}/pulls/{pr_number}/comments"
            )

        response = await api.get(comments_url)
        data = response.json()

        comments: list[GithubGetPRReviewCommentsBlock.Output.CommentItem] = [
            {
                "id": comment["id"],
                # `user` can be null in the API response; do not fail the
                # whole listing because of one orphaned comment.
                "user": (comment.get("user") or {}).get("login", ""),
                "body": comment["body"],
                # `.get(key, default)` only covers a *missing* key. GitHub can
                # send these keys with an explicit null (e.g. `line` on an
                # outdated comment), so coalesce null to the typed default.
                "path": comment.get("path") or "",
                "line": comment.get("line") or 0,
                "side": comment.get("side") or "",
                "created_at": comment["created_at"],
                "updated_at": comment["updated_at"],
                "in_reply_to_id": comment.get("in_reply_to_id"),
                "html_url": comment["html_url"],
            }
            for comment in data
        ]
        return comments

    async def run(
        self,
        input_data: Input,
        *,
        credentials: GithubCredentials,
        **kwargs,
    ) -> BlockOutput:
        """Yield ``comments`` and each ``comment``, or ``error`` on failure."""
        try:
            comments = await self.get_comments(
                credentials,
                input_data.repo,
                input_data.pr_number,
                input_data.review_id,
            )
            yield "comments", comments
            for comment in comments:
                yield "comment", comment
        except Exception as e:
            yield "error", str(e)
class GithubCreateCommentObjectBlock(Block):
    """Build a comment dictionary for use with the GitHub blocks.

    The block performs no API call: it only assembles ``path``, ``body`` and
    the optional placement fields into a single ``comment_object`` output.
    """

    class Input(BlockSchema):
        path: str = SchemaField(
            description="The file path to comment on",
            placeholder="src/main.py",
        )
        body: str = SchemaField(
            description="The comment text",
            placeholder="Please fix this issue",
        )
        position: Optional[int] = SchemaField(
            description="Position in the diff (line number from first @@ hunk). Use this OR line.",
            placeholder="6",
            default=None,
            advanced=True,
        )
        line: Optional[int] = SchemaField(
            description="Line number in the file (will be used as position if position not provided)",
            placeholder="42",
            default=None,
            advanced=True,
        )
        side: Optional[str] = SchemaField(
            description="Side of the diff to comment on (NOTE: Only for standalone comments, not review comments)",
            default="RIGHT",
            advanced=True,
        )
        start_line: Optional[int] = SchemaField(
            description="Start line for multi-line comments (NOTE: Only for standalone comments, not review comments)",
            default=None,
            advanced=True,
        )
        start_side: Optional[str] = SchemaField(
            description="Side for the start of multi-line comments (NOTE: Only for standalone comments, not review comments)",
            default=None,
            advanced=True,
        )

    class Output(BlockSchema):
        comment_object: dict = SchemaField(
            description="The comment object formatted for GitHub API"
        )

    def __init__(self):
        super().__init__(
            id="b7d5e4f2-8c3a-4e6b-9f1d-7a8b9c5e4d3f",
            description="Creates a comment object for use with GitHub blocks. Note: For review comments, only path, body, and position are used. Side fields are only for standalone PR comments.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=GithubCreateCommentObjectBlock.Input,
            output_schema=GithubCreateCommentObjectBlock.Output,
            test_input={
                "path": "src/main.py",
                "body": "Please fix this issue",
                "position": 6,
            },
            test_output=[
                (
                    "comment_object",
                    {
                        "path": "src/main.py",
                        "body": "Please fix this issue",
                        "position": 6,
                    },
                ),
            ],
        )

    async def run(
        self,
        input_data: Input,
        **kwargs,
    ) -> BlockOutput:
        """Assemble the comment dictionary and yield it as ``comment_object``."""
        payload: dict = {"path": input_data.path, "body": input_data.body}

        # An explicit position wins; otherwise the file line stands in for it.
        # Note: a line used as position may not be accurate, since position
        # should be calculated from the diff.
        position = input_data.position
        if position is None:
            position = input_data.line
        if position is not None:
            payload["position"] = position

        # Optional fields are only added when they carry a non-default value.
        side = input_data.side
        if side and side != "RIGHT":
            payload["side"] = side

        start_line = input_data.start_line
        if start_line is not None:
            payload["start_line"] = start_line

        start_side = input_data.start_side
        if start_side:
            payload["start_side"] = start_side

        yield "comment_object", payload

View File

@@ -21,8 +21,6 @@ from ._auth import (
GoogleCredentialsInput,
)
settings = Settings()
class CalendarEvent(BaseModel):
"""Structured representation of a Google Calendar event."""
@@ -223,8 +221,8 @@ class GoogleCalendarReadEventsBlock(Block):
else None
),
token_uri="https://oauth2.googleapis.com/token",
client_id=settings.secrets.google_client_id,
client_secret=settings.secrets.google_client_secret,
client_id=Settings().secrets.google_client_id,
client_secret=Settings().secrets.google_client_secret,
scopes=credentials.scopes,
)
return build("calendar", "v3", credentials=creds)
@@ -571,8 +569,8 @@ class GoogleCalendarCreateEventBlock(Block):
else None
),
token_uri="https://oauth2.googleapis.com/token",
client_id=settings.secrets.google_client_id,
client_secret=settings.secrets.google_client_secret,
client_id=Settings().secrets.google_client_id,
client_secret=Settings().secrets.google_client_secret,
scopes=credentials.scopes,
)
return build("calendar", "v3", credentials=creds)

File diff suppressed because it is too large Load Diff

View File

@@ -37,7 +37,6 @@ LLMProviderName = Literal[
ProviderName.OPENAI,
ProviderName.OPEN_ROUTER,
ProviderName.LLAMA_API,
ProviderName.V0,
]
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]
@@ -82,11 +81,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
O3 = "o3-2025-04-16"
O1 = "o1"
O1_MINI = "o1-mini"
# GPT-5 models
GPT5 = "gpt-5-2025-08-07"
GPT5_MINI = "gpt-5-mini-2025-08-07"
GPT5_NANO = "gpt-5-nano-2025-08-07"
GPT5_CHAT = "gpt-5-chat-latest"
GPT41 = "gpt-4.1-2025-04-14"
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
GPT4O_MINI = "gpt-4o-mini"
@@ -94,7 +88,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
GPT4_TURBO = "gpt-4-turbo"
GPT3_5_TURBO = "gpt-3.5-turbo"
# Anthropic models
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
CLAUDE_4_OPUS = "claude-opus-4-20250514"
CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
@@ -122,8 +115,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
OLLAMA_LLAMA3_405B = "llama3.1:405b"
OLLAMA_DOLPHIN = "dolphin-mistral:latest"
# OpenRouter models
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
GEMINI_FLASH_1_5 = "google/gemini-flash-1.5"
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
@@ -156,10 +147,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
LLAMA_API_LLAMA3_3_8B = "Llama-3.3-8B-Instruct"
LLAMA_API_LLAMA3_3_70B = "Llama-3.3-70B-Instruct"
# v0 by Vercel models
V0_1_5_MD = "v0-1.5-md"
V0_1_5_LG = "v0-1.5-lg"
V0_1_0_MD = "v0-1.0-md"
@property
def metadata(self) -> ModelMetadata:
@@ -184,11 +171,6 @@ MODEL_METADATA = {
LlmModel.O3_MINI: ModelMetadata("openai", 200000, 100000), # o3-mini-2025-01-31
LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
# GPT-5 models
LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_MINI: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_NANO: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_CHAT: ModelMetadata("openai", 400000, 16384),
LlmModel.GPT41: ModelMetadata("openai", 1047576, 32768),
LlmModel.GPT41_MINI: ModelMetadata("openai", 1047576, 32768),
LlmModel.GPT4O_MINI: ModelMetadata(
@@ -200,9 +182,6 @@ MODEL_METADATA = {
), # gpt-4-turbo-2024-04-09
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, 4096), # gpt-3.5-turbo-0125
# https://docs.anthropic.com/en/docs/about-claude/models
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
"anthropic", 200000, 32000
), # claude-opus-4-1-20250805
LlmModel.CLAUDE_4_OPUS: ModelMetadata(
"anthropic", 200000, 8192
), # claude-4-opus-20250514
@@ -267,8 +246,6 @@ MODEL_METADATA = {
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata(
"open_router", 12288, 12288
),
LlmModel.OPENAI_GPT_OSS_120B: ModelMetadata("open_router", 131072, 131072),
LlmModel.OPENAI_GPT_OSS_20B: ModelMetadata("open_router", 131072, 32768),
LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata("open_router", 300000, 5120),
LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata("open_router", 128000, 5120),
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata("open_router", 300000, 5120),
@@ -285,10 +262,6 @@ MODEL_METADATA = {
LlmModel.LLAMA_API_LLAMA4_MAVERICK: ModelMetadata("llama_api", 128000, 4028),
LlmModel.LLAMA_API_LLAMA3_3_8B: ModelMetadata("llama_api", 128000, 4028),
LlmModel.LLAMA_API_LLAMA3_3_70B: ModelMetadata("llama_api", 128000, 4028),
# v0 by Vercel models
LlmModel.V0_1_5_MD: ModelMetadata("v0", 128000, 64000),
LlmModel.V0_1_5_LG: ModelMetadata("v0", 512000, 64000),
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000),
}
for model in LlmModel:
@@ -502,7 +475,6 @@ async def llm_call(
messages=messages,
max_tokens=max_tokens,
tools=an_tools,
timeout=600,
)
if not resp.content:
@@ -685,11 +657,7 @@ async def llm_call(
client = openai.OpenAI(
base_url="https://api.aimlapi.com/v2",
api_key=credentials.api_key.get_secret_value(),
default_headers={
"X-Project": "AutoGPT",
"X-Title": "AutoGPT",
"HTTP-Referer": "https://github.com/Significant-Gravitas/AutoGPT",
},
default_headers={"X-Project": "AutoGPT"},
)
completion = client.chat.completions.create(
@@ -709,42 +677,6 @@ async def llm_call(
),
reasoning=None,
)
elif provider == "v0":
tools_param = tools if tools else openai.NOT_GIVEN
client = openai.AsyncOpenAI(
base_url="https://api.v0.dev/v1",
api_key=credentials.api_key.get_secret_value(),
)
response_format = None
if json_format:
response_format = {"type": "json_object"}
parallel_tool_calls_param = get_parallel_tool_calls_param(
llm_model, parallel_tool_calls
)
response = await client.chat.completions.create(
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
parallel_tool_calls=parallel_tool_calls_param,
)
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
return LLMResponse(
raw_response=response.choices[0].message,
prompt=prompt,
response=response.choices[0].message.content or "",
tool_calls=tool_calls,
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
reasoning=reasoning,
)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")

View File

@@ -1,13 +1,22 @@
import logging
from typing import Any, Literal
from autogpt_libs.utils.cache import thread_cached
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.clients import get_database_manager_async_client
logger = logging.getLogger(__name__)
@thread_cached
def get_database_manager_client():
    """Return a `DatabaseManagerAsyncClient` service client.

    The client is created with ``health_check=False``. The result is cached by
    the `thread_cached` decorator (presumably once per thread - confirm
    against its implementation).
    """
    # Imported inside the function rather than at module level
    # (presumably to avoid an import cycle - confirm).
    from backend.executor import DatabaseManagerAsyncClient
    from backend.util.service import get_service_client

    client = get_service_client(DatabaseManagerAsyncClient, health_check=False)
    return client
StorageScope = Literal["within_agent", "across_agents"]
@@ -79,7 +88,7 @@ class PersistInformationBlock(Block):
async def _store_data(
self, user_id: str, node_exec_id: str, key: str, data: Any
) -> Any | None:
return await get_database_manager_async_client().set_execution_kv_data(
return await get_database_manager_client().set_execution_kv_data(
user_id=user_id,
node_exec_id=node_exec_id,
key=key,
@@ -140,7 +149,7 @@ class RetrieveInformationBlock(Block):
yield "value", input_data.default_value
async def _retrieve_data(self, user_id: str, key: str) -> Any | None:
return await get_database_manager_async_client().get_execution_kv_data(
return await get_database_manager_client().get_execution_kv_data(
user_id=user_id,
key=key,
)

View File

@@ -3,7 +3,8 @@ from typing import List
from backend.data.block import BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, SchemaField
from backend.util.settings import BehaveAs, Settings
from backend.util import settings
from backend.util.settings import BehaveAs
from ._api import (
TEST_CREDENTIALS,
@@ -15,8 +16,6 @@ from ._api import (
)
from .base import Slant3DBlockBase
settings = Settings()
class Slant3DCreateOrderBlock(Slant3DBlockBase):
"""Block for creating new orders"""
@@ -281,7 +280,7 @@ class Slant3DGetOrdersBlock(Slant3DBlockBase):
input_schema=self.Input,
output_schema=self.Output,
# This block is disabled for cloud hosted because it allows access to all orders for the account
disabled=settings.config.behave_as == BehaveAs.CLOUD,
disabled=settings.Settings().config.behave_as == BehaveAs.CLOUD,
test_input={"credentials": TEST_CREDENTIALS_INPUT},
test_credentials=TEST_CREDENTIALS,
test_output=[

View File

@@ -9,7 +9,8 @@ from backend.data.block import (
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from backend.util.settings import AppEnvironment, BehaveAs, Settings
from backend.util import settings
from backend.util.settings import AppEnvironment, BehaveAs
from ._api import (
TEST_CREDENTIALS,
@@ -18,8 +19,6 @@ from ._api import (
Slant3DCredentialsInput,
)
settings = Settings()
class Slant3DTriggerBase:
"""Base class for Slant3D webhook triggers"""
@@ -77,8 +76,8 @@ class Slant3DOrderWebhookBlock(Slant3DTriggerBase, Block):
),
# All webhooks are currently subscribed to for all orders. This works for self hosted, but not for cloud hosted prod
disabled=(
settings.config.behave_as == BehaveAs.CLOUD
and settings.config.app_env != AppEnvironment.LOCAL
settings.Settings().config.behave_as == BehaveAs.CLOUD
and settings.Settings().config.app_env != AppEnvironment.LOCAL
),
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=self.Input,

View File

@@ -3,6 +3,8 @@ import re
from collections import Counter
from typing import TYPE_CHECKING, Any
from autogpt_libs.utils.cache import thread_cached
import backend.blocks.llm as llm
from backend.blocks.agent import AgentExecutorBlock
from backend.data.block import (
@@ -15,7 +17,6 @@ from backend.data.block import (
)
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util import json
from backend.util.clients import get_database_manager_async_client
if TYPE_CHECKING:
from backend.data.graph import Link, Node
@@ -23,6 +24,14 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@thread_cached
def get_database_manager_client():
    """Return a `DatabaseManagerAsyncClient` service client.

    The client is created with ``health_check=False``. The result is cached by
    the `thread_cached` decorator (presumably once per thread - confirm
    against its implementation).
    """
    # Imported inside the function rather than at module level
    # (presumably to avoid an import cycle - confirm).
    from backend.executor import DatabaseManagerAsyncClient
    from backend.util.service import get_service_client

    client = get_service_client(DatabaseManagerAsyncClient, health_check=False)
    return client
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
"""
Return a list of tool_call_ids if the entry is a tool request.
@@ -291,32 +300,9 @@ class SmartDecisionMakerBlock(Block):
for link in links:
sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
# Handle dynamic fields (e.g., values_#_*, items_$_*, etc.)
# These are fields that get merged by the executor into their base field
if (
"_#_" in link.sink_name
or "_$_" in link.sink_name
or "_@_" in link.sink_name
):
# For dynamic fields, provide a generic string schema
# The executor will handle merging these into the appropriate structure
properties[sink_name] = {
"type": "string",
"description": f"Dynamic value for {link.sink_name}",
}
else:
# For regular fields, use the block's schema
try:
properties[sink_name] = sink_block_input_schema.get_field_schema(
link.sink_name
)
except (KeyError, AttributeError):
# If the field doesn't exist in the schema, provide a generic schema
properties[sink_name] = {
"type": "string",
"description": f"Value for {link.sink_name}",
}
properties[sink_name] = sink_block_input_schema.get_field_schema(
link.sink_name
)
tool_function["parameters"] = {
**block.input_schema.jsonschema(),
@@ -347,7 +333,7 @@ class SmartDecisionMakerBlock(Block):
if not graph_id or not graph_version:
raise ValueError("Graph ID or Graph Version not found in sink node.")
db_client = get_database_manager_async_client()
db_client = get_database_manager_client()
sink_graph_meta = await db_client.get_graph_metadata(graph_id, graph_version)
if not sink_graph_meta:
raise ValueError(
@@ -407,7 +393,7 @@ class SmartDecisionMakerBlock(Block):
ValueError: If no tool links are found for the specified node_id, or if a sink node
or its metadata cannot be found.
"""
db_client = get_database_manager_async_client()
db_client = get_database_manager_client()
tools = [
(link, node)
for link, node in await db_client.get_connected_output_nodes(node_id)
@@ -501,6 +487,10 @@ class SmartDecisionMakerBlock(Block):
}
)
prompt.extend(tool_output)
if input_data.multiple_tool_calls:
input_data.sys_prompt += "\nYou can call a tool (different tools) multiple times in a single response."
else:
input_data.sys_prompt += "\nOnly provide EXACTLY one function call, multiple tool calls is strictly prohibited."
values = input_data.prompt_values
if values:
@@ -539,6 +529,15 @@ class SmartDecisionMakerBlock(Block):
)
)
# Add reasoning to conversation history if available
if response.reasoning:
prompt.append(
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
)
prompt.append(response.raw_response)
yield "conversations", prompt
if not response.tool_calls:
yield "finished", response.response
return
@@ -572,12 +571,3 @@ class SmartDecisionMakerBlock(Block):
yield f"tools_^_{tool_name}_~_{arg_name}", tool_args[arg_name]
else:
yield f"tools_^_{tool_name}_~_{arg_name}", None
# Add reasoning to conversation history if available
if response.reasoning:
prompt.append(
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
)
prompt.append(response.raw_response)
yield "conversations", prompt

View File

@@ -1,8 +1,9 @@
import logging
import pytest
from prisma.models import User
from backend.data.model import ProviderName, User
from backend.data.model import ProviderName
from backend.server.model import CreateGraph
from backend.server.rest_api import AgentServer
from backend.usecases.sample import create_test_graph, create_test_user

View File

@@ -1,130 +0,0 @@
from unittest.mock import Mock
import pytest
from backend.blocks.data_manipulation import AddToListBlock, CreateDictionaryBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
@pytest.mark.asyncio
async def test_smart_decision_maker_handles_dynamic_dict_fields():
    """Dynamic dictionary links (``_#_``) must yield one string property each."""
    # Mock node standing in for a CreateDictionaryBlock in the graph.
    mock_node = Mock()
    mock_node.block = CreateDictionaryBlock()
    mock_node.block_id = CreateDictionaryBlock().id
    mock_node.input_default = {}

    # (source_name, sink_name) of every link; all sinks are dynamic dict fields.
    link_names = (
        ("tools_^_create_dict_~_name", "values_#_name"),
        ("tools_^_create_dict_~_age", "values_#_age"),
        ("tools_^_create_dict_~_city", "values_#_city"),
    )
    mock_links = [
        Mock(
            source_name=source_name,
            sink_name=sink_name,
            sink_id="dict_node_id",
            source_id="smart_decision_node_id",
        )
        for source_name, sink_name in link_names
    ]

    signature = await SmartDecisionMakerBlock._create_block_function_signature(
        mock_node, mock_links  # type: ignore
    )

    # The signature must be a function definition exposing its parameters.
    assert signature["type"] == "function"
    function_spec = signature["function"]
    assert "parameters" in function_spec
    assert "properties" in function_spec["parameters"]

    # All three dynamic fields must be present.
    properties = function_spec["parameters"]["properties"]
    assert len(properties) == 3

    # Every dynamic field gets the generic string schema.
    for field_schema in properties.values():
        assert "type" in field_schema
        assert field_schema["type"] == "string"
        assert "description" in field_schema
        assert "Dynamic value for" in field_schema["description"]
@pytest.mark.asyncio
async def test_smart_decision_maker_handles_dynamic_list_fields():
    """Dynamic list links (``_$_``) must yield one string property each."""
    # Mock node standing in for an AddToListBlock in the graph.
    mock_node = Mock()
    mock_node.block = AddToListBlock()
    mock_node.block_id = AddToListBlock().id
    mock_node.input_default = {}

    # (source_name, sink_name) of every link; all sinks are dynamic list fields.
    link_names = (
        ("tools_^_add_to_list_~_0", "entries_$_0"),
        ("tools_^_add_to_list_~_1", "entries_$_1"),
    )
    mock_links = [
        Mock(
            source_name=source_name,
            sink_name=sink_name,
            sink_id="list_node_id",
            source_id="smart_decision_node_id",
        )
        for source_name, sink_name in link_names
    ]

    signature = await SmartDecisionMakerBlock._create_block_function_signature(
        mock_node, mock_links  # type: ignore
    )

    assert signature["type"] == "function"

    # Both list items must be present.
    properties = signature["function"]["parameters"]["properties"]
    assert len(properties) == 2

    # Every dynamic field gets the generic string schema.
    for field_schema in properties.values():
        assert field_schema["type"] == "string"
        assert "Dynamic value for" in field_schema["description"]
@pytest.mark.asyncio
async def test_create_dict_block_with_dynamic_values():
    """CreateDictionaryBlock must pass already-merged dynamic values through."""
    block = CreateDictionaryBlock()

    # Simulates the input after the executor has merged the values_#_* fields
    # into the single `values` dictionary.
    merged_values = {
        "existing": "value",
        "name": "Alice",  # This would come from values_#_name
        "age": 25,  # This would come from values_#_age
    }
    input_data = block.input_schema(values=merged_values)

    # Run the block and collect its outputs by name.
    result = {
        output_name: output_value
        async for output_name, output_value in block.run(input_data)
    }

    assert "dictionary" in result
    dictionary = result["dictionary"]
    assert dictionary["existing"] == "value"
    assert dictionary["name"] == "Alice"
    assert dictionary["age"] == 25

View File

@@ -1,78 +1,19 @@
import asyncio
import time
from datetime import datetime, timedelta
from typing import Any, Literal, Union
from zoneinfo import ZoneInfo
from pydantic import BaseModel
from typing import Any, Union
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
# Shared timezone literal type for all time/date blocks
TimezoneLiteral = Literal[
"UTC", # UTC±00:00
"Pacific/Honolulu", # UTC-10:00
"America/Anchorage", # UTC-09:00 (Alaska)
"America/Los_Angeles", # UTC-08:00 (Pacific)
"America/Denver", # UTC-07:00 (Mountain)
"America/Chicago", # UTC-06:00 (Central)
"America/New_York", # UTC-05:00 (Eastern)
"America/Caracas", # UTC-04:00
"America/Sao_Paulo", # UTC-03:00
"America/St_Johns", # UTC-02:30 (Newfoundland)
"Atlantic/South_Georgia", # UTC-02:00
"Atlantic/Azores", # UTC-01:00
"Europe/London", # UTC+00:00 (GMT/BST)
"Europe/Paris", # UTC+01:00 (CET)
"Europe/Athens", # UTC+02:00 (EET)
"Europe/Moscow", # UTC+03:00
"Asia/Tehran", # UTC+03:30 (Iran)
"Asia/Dubai", # UTC+04:00
"Asia/Kabul", # UTC+04:30 (Afghanistan)
"Asia/Karachi", # UTC+05:00 (Pakistan)
"Asia/Kolkata", # UTC+05:30 (India)
"Asia/Kathmandu", # UTC+05:45 (Nepal)
"Asia/Dhaka", # UTC+06:00 (Bangladesh)
"Asia/Yangon", # UTC+06:30 (Myanmar)
"Asia/Bangkok", # UTC+07:00
"Asia/Shanghai", # UTC+08:00 (China)
"Australia/Eucla", # UTC+08:45
"Asia/Tokyo", # UTC+09:00 (Japan)
"Australia/Adelaide", # UTC+09:30
"Australia/Sydney", # UTC+10:00
"Australia/Lord_Howe", # UTC+10:30
"Pacific/Noumea", # UTC+11:00
"Pacific/Auckland", # UTC+12:00 (New Zealand)
"Pacific/Chatham", # UTC+12:45
"Pacific/Tongatapu", # UTC+13:00
"Pacific/Kiritimati", # UTC+14:00
"Etc/GMT-12", # UTC+12:00
"Etc/GMT+12", # UTC-12:00
]
class TimeStrftimeFormat(BaseModel):
discriminator: Literal["strftime"]
format: str = "%H:%M:%S"
timezone: TimezoneLiteral = "UTC"
class TimeISO8601Format(BaseModel):
discriminator: Literal["iso8601"]
timezone: TimezoneLiteral = "UTC"
include_microseconds: bool = False
class GetCurrentTimeBlock(Block):
class Input(BlockSchema):
trigger: str = SchemaField(
description="Trigger any data to output the current time"
)
format_type: Union[TimeStrftimeFormat, TimeISO8601Format] = SchemaField(
discriminator="discriminator",
description="Format type for time output (strftime with custom format or ISO 8601)",
default=TimeStrftimeFormat(discriminator="strftime"),
format: str = SchemaField(
description="Format of the time to output", default="%H:%M:%S"
)
class Output(BlockSchema):
@@ -89,65 +30,19 @@ class GetCurrentTimeBlock(Block):
output_schema=GetCurrentTimeBlock.Output,
test_input=[
{"trigger": "Hello"},
{
"trigger": "Hello",
"format_type": {
"discriminator": "strftime",
"format": "%H:%M",
},
},
{
"trigger": "Hello",
"format_type": {
"discriminator": "iso8601",
"timezone": "UTC",
"include_microseconds": False,
},
},
{"trigger": "Hello", "format": "%H:%M"},
],
test_output=[
("time", lambda _: time.strftime("%H:%M:%S")),
("time", lambda _: time.strftime("%H:%M")),
(
"time",
lambda t: "T" in t and ("+" in t or "Z" in t),
), # Check for ISO format with timezone
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
if isinstance(input_data.format_type, TimeISO8601Format):
# ISO 8601 format for time only (extract time portion from full ISO datetime)
tz = ZoneInfo(input_data.format_type.timezone)
dt = datetime.now(tz=tz)
# Get the full ISO format and extract just the time portion with timezone
if input_data.format_type.include_microseconds:
full_iso = dt.isoformat()
else:
full_iso = dt.isoformat(timespec="seconds")
# Extract time portion (everything after 'T')
current_time = full_iso.split("T")[1] if "T" in full_iso else full_iso
current_time = f"T{current_time}" # Add T prefix for ISO 8601 time format
else: # TimeStrftimeFormat
tz = ZoneInfo(input_data.format_type.timezone)
dt = datetime.now(tz=tz)
current_time = dt.strftime(input_data.format_type.format)
current_time = time.strftime(input_data.format)
yield "time", current_time
class DateStrftimeFormat(BaseModel):
discriminator: Literal["strftime"]
format: str = "%Y-%m-%d"
timezone: TimezoneLiteral = "UTC"
class DateISO8601Format(BaseModel):
discriminator: Literal["iso8601"]
timezone: TimezoneLiteral = "UTC"
class GetCurrentDateBlock(Block):
class Input(BlockSchema):
trigger: str = SchemaField(
@@ -158,10 +53,8 @@ class GetCurrentDateBlock(Block):
description="Offset in days from the current date",
default=0,
)
format_type: Union[DateStrftimeFormat, DateISO8601Format] = SchemaField(
discriminator="discriminator",
description="Format type for date output (strftime with custom format or ISO 8601)",
default=DateStrftimeFormat(discriminator="strftime"),
format: str = SchemaField(
description="Format of the date to output", default="%Y-%m-%d"
)
class Output(BlockSchema):
@@ -178,22 +71,7 @@ class GetCurrentDateBlock(Block):
output_schema=GetCurrentDateBlock.Output,
test_input=[
{"trigger": "Hello", "offset": "7"},
{
"trigger": "Hello",
"offset": "7",
"format_type": {
"discriminator": "strftime",
"format": "%m/%d/%Y",
},
},
{
"trigger": "Hello",
"offset": "0",
"format_type": {
"discriminator": "iso8601",
"timezone": "UTC",
},
},
{"trigger": "Hello", "offset": "7", "format": "%m/%d/%Y"},
],
test_output=[
(
@@ -207,12 +85,6 @@ class GetCurrentDateBlock(Block):
< timedelta(days=8),
# 7 days difference + 1 day error margin.
),
(
"date",
lambda t: len(t) == 10
and t[4] == "-"
and t[7] == "-", # ISO date format YYYY-MM-DD
),
],
)
@@ -221,31 +93,8 @@ class GetCurrentDateBlock(Block):
offset = int(input_data.offset)
except ValueError:
offset = 0
if isinstance(input_data.format_type, DateISO8601Format):
# ISO 8601 format for date only (YYYY-MM-DD)
tz = ZoneInfo(input_data.format_type.timezone)
current_date = datetime.now(tz=tz) - timedelta(days=offset)
# ISO 8601 date format is YYYY-MM-DD
date_str = current_date.date().isoformat()
else: # DateStrftimeFormat
tz = ZoneInfo(input_data.format_type.timezone)
current_date = datetime.now(tz=tz) - timedelta(days=offset)
date_str = current_date.strftime(input_data.format_type.format)
yield "date", date_str
class StrftimeFormat(BaseModel):
discriminator: Literal["strftime"]
format: str = "%Y-%m-%d %H:%M:%S"
timezone: TimezoneLiteral = "UTC"
class ISO8601Format(BaseModel):
discriminator: Literal["iso8601"]
timezone: TimezoneLiteral = "UTC"
include_microseconds: bool = False
current_date = datetime.now() - timedelta(days=offset)
yield "date", current_date.strftime(input_data.format)
class GetCurrentDateAndTimeBlock(Block):
@@ -253,10 +102,9 @@ class GetCurrentDateAndTimeBlock(Block):
trigger: str = SchemaField(
description="Trigger any data to output the current date and time"
)
format_type: Union[StrftimeFormat, ISO8601Format] = SchemaField(
discriminator="discriminator",
description="Format type for date and time output (strftime with custom format or ISO 8601/RFC 3339)",
default=StrftimeFormat(discriminator="strftime"),
format: str = SchemaField(
description="Format of the date and time to output",
default="%Y-%m-%d %H:%M:%S",
)
class Output(BlockSchema):
@@ -273,63 +121,20 @@ class GetCurrentDateAndTimeBlock(Block):
output_schema=GetCurrentDateAndTimeBlock.Output,
test_input=[
{"trigger": "Hello"},
{
"trigger": "Hello",
"format_type": {
"discriminator": "strftime",
"format": "%Y/%m/%d",
},
},
{
"trigger": "Hello",
"format_type": {
"discriminator": "iso8601",
"timezone": "UTC",
"include_microseconds": False,
},
},
],
test_output=[
(
"date_time",
lambda t: abs(
datetime.now(tz=ZoneInfo("UTC"))
- datetime.strptime(t + "+00:00", "%Y-%m-%d %H:%M:%S%z")
datetime.now() - datetime.strptime(t, "%Y-%m-%d %H:%M:%S")
)
< timedelta(seconds=10), # 10 seconds error margin.
),
(
"date_time",
lambda t: abs(
datetime.now().date() - datetime.strptime(t, "%Y/%m/%d").date()
)
< timedelta(days=1), # Date format only, no time component
),
(
"date_time",
lambda t: abs(
datetime.now(tz=ZoneInfo("UTC")) - datetime.fromisoformat(t)
)
< timedelta(seconds=10), # 10 seconds error margin for ISO format.
),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
if isinstance(input_data.format_type, ISO8601Format):
# ISO 8601 format with specified timezone (also RFC3339-compliant)
tz = ZoneInfo(input_data.format_type.timezone)
dt = datetime.now(tz=tz)
# Format with or without microseconds
if input_data.format_type.include_microseconds:
current_date_time = dt.isoformat()
else:
current_date_time = dt.isoformat(timespec="seconds")
else: # StrftimeFormat
tz = ZoneInfo(input_data.format_type.timezone)
dt = datetime.now(tz=tz)
current_date_time = dt.strftime(input_data.format_type.format)
current_date_time = time.strftime(input_data.format)
yield "date_time", current_date_time

View File

@@ -5,12 +5,6 @@ from backend.blocks.ai_shortform_video_block import AIShortformVideoCreatorBlock
from backend.blocks.apollo.organization import SearchOrganizationsBlock
from backend.blocks.apollo.people import SearchPeopleBlock
from backend.blocks.apollo.person import GetPersonDetailBlock
from backend.blocks.enrichlayer.linkedin import (
GetLinkedinProfileBlock,
GetLinkedinProfilePictureBlock,
LinkedinPersonLookupBlock,
LinkedinRoleLookupBlock,
)
from backend.blocks.flux_kontext import AIImageEditorBlock, FluxKontextModelName
from backend.blocks.ideogram import IdeogramModelBlock
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
@@ -36,7 +30,6 @@ from backend.integrations.credentials_store import (
anthropic_credentials,
apollo_credentials,
did_credentials,
enrichlayer_credentials,
groq_credentials,
ideogram_credentials,
jina_credentials,
@@ -46,7 +39,6 @@ from backend.integrations.credentials_store import (
replicate_credentials,
revid_credentials,
unreal_credentials,
v0_credentials,
)
# =============== Configure the cost for each LLM Model call =============== #
@@ -56,18 +48,12 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.O3_MINI: 2, # $1.10 / $4.40
LlmModel.O1: 16, # $15 / $60
LlmModel.O1_MINI: 4,
# GPT-5 models
LlmModel.GPT5: 2,
LlmModel.GPT5_MINI: 1,
LlmModel.GPT5_NANO: 1,
LlmModel.GPT5_CHAT: 2,
LlmModel.GPT41: 2,
LlmModel.GPT41_MINI: 1,
LlmModel.GPT4O_MINI: 1,
LlmModel.GPT4O: 3,
LlmModel.GPT4_TURBO: 10,
LlmModel.GPT3_5_TURBO: 1,
LlmModel.CLAUDE_4_1_OPUS: 21,
LlmModel.CLAUDE_4_OPUS: 21,
LlmModel.CLAUDE_4_SONNET: 5,
LlmModel.CLAUDE_3_7_SONNET: 5,
@@ -90,8 +76,6 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.OLLAMA_LLAMA3_405B: 1,
LlmModel.DEEPSEEK_LLAMA_70B: 1, # ? / ?
LlmModel.OLLAMA_DOLPHIN: 1,
LlmModel.OPENAI_GPT_OSS_120B: 1,
LlmModel.OPENAI_GPT_OSS_20B: 1,
LlmModel.GEMINI_FLASH_1_5: 1,
LlmModel.GEMINI_2_5_PRO: 4,
LlmModel.MISTRAL_NEMO: 1,
@@ -123,10 +107,6 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
LlmModel.GEMINI_2_0_FLASH_LITE: 1,
LlmModel.DEEPSEEK_R1_0528: 1,
# v0 by Vercel models
LlmModel.V0_1_5_MD: 1,
LlmModel.V0_1_5_LG: 2,
LlmModel.V0_1_0_MD: 1,
}
for model in LlmModel:
@@ -216,23 +196,6 @@ LLM_COST = (
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "llama_api"
]
# v0 by Vercel Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": v0_credentials.id,
"provider": v0_credentials.provider,
"type": v0_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "v0"
]
# AI/ML Api Models
+ [
BlockCost(
@@ -405,54 +368,6 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
},
)
],
GetLinkedinProfileBlock: [
BlockCost(
cost_amount=1,
cost_filter={
"credentials": {
"id": enrichlayer_credentials.id,
"provider": enrichlayer_credentials.provider,
"type": enrichlayer_credentials.type,
}
},
)
],
LinkedinPersonLookupBlock: [
BlockCost(
cost_amount=2,
cost_filter={
"credentials": {
"id": enrichlayer_credentials.id,
"provider": enrichlayer_credentials.provider,
"type": enrichlayer_credentials.type,
}
},
)
],
LinkedinRoleLookupBlock: [
BlockCost(
cost_amount=3,
cost_filter={
"credentials": {
"id": enrichlayer_credentials.id,
"provider": enrichlayer_credentials.provider,
"type": enrichlayer_credentials.type,
}
},
)
],
GetLinkedinProfilePictureBlock: [
BlockCost(
cost_amount=3,
cost_filter={
"credentials": {
"id": enrichlayer_credentials.id,
"provider": enrichlayer_credentials.provider,
"type": enrichlayer_credentials.type,
}
},
)
],
SmartDecisionMakerBlock: LLM_COST,
SearchOrganizationsBlock: [
BlockCost(

View File

@@ -34,10 +34,10 @@ from backend.data.model import (
from backend.data.notifications import NotificationEventModel, RefundRequestData
from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.notifications.notifications import queue_notification_async
from backend.server.model import Pagination
from backend.server.v2.admin.model import UserHistoryResponse
from backend.util.exceptions import InsufficientBalanceError
from backend.util.json import SafeJson
from backend.util.models import Pagination
from backend.util.retry import func_retry
from backend.util.settings import Settings
@@ -286,17 +286,11 @@ class UserCreditBase(ABC):
transaction = await CreditTransaction.prisma().find_first_or_raise(
where={"transactionKey": transaction_key, "userId": user_id}
)
if transaction.isActive:
return
async with db.locked_transaction(f"usr_trx_{user_id}"):
transaction = await CreditTransaction.prisma().find_first_or_raise(
where={"transactionKey": transaction_key, "userId": user_id}
)
if transaction.isActive:
return
user_balance, _ = await self._get_credits(user_id)
await CreditTransaction.prisma().update(
where={
@@ -1004,8 +998,8 @@ def get_block_costs() -> dict[str, list[BlockCost]]:
async def get_stripe_customer_id(user_id: str) -> str:
user = await get_user_by_id(user_id)
if user.stripe_customer_id:
return user.stripe_customer_id
if user.stripeCustomerId:
return user.stripeCustomerId
customer = stripe.Customer.create(
name=user.name or "",
@@ -1028,10 +1022,10 @@ async def set_auto_top_up(user_id: str, config: AutoTopUpConfig):
async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
user = await get_user_by_id(user_id)
if not user.top_up_config:
if not user.topUpConfig:
return AutoTopUpConfig(threshold=0, amount=0)
return AutoTopUpConfig.model_validate(user.top_up_config)
return AutoTopUpConfig.model_validate(user.topUpConfig)
async def admin_get_user_history(

View File

@@ -1,5 +1,6 @@
import logging
import os
import zlib
from contextlib import asynccontextmanager
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
from uuid import uuid4
@@ -49,10 +50,6 @@ prisma = Prisma(
logger = logging.getLogger(__name__)
def is_connected():
return prisma.is_connected()
@conn_retry("Prisma", "Acquiring connection")
async def connect():
if prisma.is_connected():
@@ -87,50 +84,35 @@ TRANSACTION_TIMEOUT = 15000 # 15 seconds - Increased from 5s to prevent timeout
@asynccontextmanager
async def transaction(timeout: int = TRANSACTION_TIMEOUT):
async def transaction(timeout: int | None = None):
"""
Create a database transaction with optional timeout.
Args:
timeout: Transaction timeout in milliseconds. If None, uses TRANSACTION_TIMEOUT (15s).
"""
if timeout is None:
timeout = TRANSACTION_TIMEOUT
async with prisma.tx(timeout=timeout) as tx:
yield tx
@asynccontextmanager
async def locked_transaction(key: str, timeout: int = TRANSACTION_TIMEOUT):
async def locked_transaction(key: str, timeout: int | None = None):
"""
Create a transaction and take a per-key advisory *transaction* lock.
- Uses a 64-bit lock id via hashtextextended(key, 0) to avoid 32-bit collisions.
- Bound by lock_timeout and statement_timeout so it won't block indefinitely.
- Lock is held for the duration of the transaction and auto-released on commit/rollback.
Create a database transaction with advisory lock.
Args:
key: String lock key (e.g., "usr_trx_<uuid>").
timeout: Transaction/lock/statement timeout in milliseconds.
key: Lock key for advisory lock
timeout: Transaction timeout in milliseconds. If None, uses TRANSACTION_TIMEOUT (15s).
"""
if timeout is None:
timeout = TRANSACTION_TIMEOUT
lock_key = zlib.crc32(key.encode("utf-8"))
async with transaction(timeout=timeout) as tx:
# Ensure we don't wait longer than desired
# Note: SET LOCAL doesn't support parameterized queries, must use string interpolation
await tx.execute_raw(f"SET LOCAL statement_timeout = '{int(timeout)}ms'") # type: ignore[arg-type]
await tx.execute_raw(f"SET LOCAL lock_timeout = '{int(timeout)}ms'") # type: ignore[arg-type]
# Block until acquired or lock_timeout hits
try:
await tx.execute_raw(
"SELECT pg_advisory_xact_lock(hashtextextended($1, 0))",
key,
)
except Exception as e:
# Normalize PG's lock timeout error to TimeoutError for callers
if "lock timeout" in str(e).lower():
raise TimeoutError(
f"Could not acquire lock for key={key!r} within {timeout}ms"
) from e
raise
await tx.execute_raw("SELECT pg_advisory_xact_lock($1)", lock_key)
yield tx

View File

@@ -33,13 +33,12 @@ from prisma.types import (
AgentNodeExecutionUpdateInput,
AgentNodeExecutionWhereInput,
)
from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
from pydantic import BaseModel, ConfigDict, JsonValue
from pydantic.fields import Field
from backend.server.v2.store.exceptions import DatabaseError
from backend.util import type as type_utils
from backend.util.json import SafeJson
from backend.util.retry import func_retry
from backend.util.settings import Config
from backend.util.truncate import truncate
@@ -135,10 +134,6 @@ class GraphExecutionMeta(BaseDbModel):
default=None,
description="Error message if any",
)
activity_status: str | None = Field(
default=None,
description="AI-generated summary of what the agent did",
)
def to_db(self) -> GraphExecutionStats:
return GraphExecutionStats(
@@ -150,7 +145,6 @@ class GraphExecutionMeta(BaseDbModel):
node_count=self.node_exec_count,
node_error_count=self.node_error_count,
error=self.error,
activity_status=self.activity_status,
)
stats: Stats | None
@@ -195,7 +189,6 @@ class GraphExecutionMeta(BaseDbModel):
if isinstance(stats.error, Exception)
else stats.error
),
activity_status=stats.activity_status,
)
if stats
else None
@@ -318,30 +311,18 @@ class NodeExecutionResult(BaseModel):
@staticmethod
def from_db(_node_exec: AgentNodeExecution, user_id: Optional[str] = None):
try:
stats = NodeExecutionStats.model_validate(_node_exec.stats or {})
except (ValueError, ValidationError):
stats = NodeExecutionStats()
if stats.cleared_inputs:
input_data: BlockInput = defaultdict()
for name, messages in stats.cleared_inputs.items():
input_data[name] = messages[-1] if messages else ""
elif _node_exec.executionData:
if _node_exec.executionData:
# Execution that has been queued for execution will persist its data.
input_data = type_utils.convert(_node_exec.executionData, dict[str, Any])
else:
# For incomplete execution, executionData will not be yet available.
input_data: BlockInput = defaultdict()
for data in _node_exec.Input or []:
input_data[data.name] = type_utils.convert(data.data, type[Any])
output_data: CompletedBlockOutput = defaultdict(list)
if stats.cleared_outputs:
for name, messages in stats.cleared_outputs.items():
output_data[name].extend(messages)
else:
for data in _node_exec.Output or []:
output_data[data.name].append(type_utils.convert(data.data, type[Any]))
for data in _node_exec.Output or []:
output_data[data.name].append(type_utils.convert(data.data, type[Any]))
graph_execution: AgentGraphExecution | None = _node_exec.GraphExecution
if graph_execution:
@@ -649,8 +630,6 @@ async def update_graph_execution_stats(
"OR": [
{"executionStatus": ExecutionStatus.RUNNING},
{"executionStatus": ExecutionStatus.QUEUED},
# Terminated graph can be resumed.
{"executionStatus": ExecutionStatus.TERMINATED},
],
},
data=update_data,
@@ -667,6 +646,27 @@ async def update_graph_execution_stats(
return GraphExecution.from_db(graph_exec)
async def update_node_execution_stats(
    node_exec_id: str, stats: NodeExecutionStats
) -> NodeExecutionResult:
    """
    Persist `stats` on a node execution and stamp its end time.

    Args:
        node_exec_id: ID of the AgentNodeExecution row to update.
        stats: Stats to serialize into the row's `stats` column.

    Returns:
        The updated execution converted via `NodeExecutionResult.from_db`.

    Raises:
        ValueError: If the update returns no row for `node_exec_id`.
    """
    payload = stats.model_dump()

    # Exceptions are not JSON-serializable; store their message instead.
    error = payload["error"]
    if isinstance(error, Exception):
        payload["error"] = str(error)

    updated = await AgentNodeExecution.prisma().update(
        where={"id": node_exec_id},
        data={
            "stats": SafeJson(payload),
            "endedTime": datetime.now(tz=timezone.utc),
        },
        include=EXECUTION_RESULT_INCLUDE,
    )
    if updated:
        return NodeExecutionResult.from_db(updated)

    raise ValueError(f"Node execution {node_exec_id} not found.")
async def update_node_execution_status_batch(
node_exec_ids: list[str],
status: ExecutionStatus,
@@ -896,15 +896,15 @@ class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):
def publish(self, res: GraphExecution | NodeExecutionResult):
if isinstance(res, GraphExecution):
self._publish_graph_exec_update(res)
self.publish_graph_exec_update(res)
else:
self._publish_node_exec_update(res)
self.publish_node_exec_update(res)
def _publish_node_exec_update(self, res: NodeExecutionResult):
def publish_node_exec_update(self, res: NodeExecutionResult):
event = NodeExecutionEvent.model_validate(res.model_dump())
self._publish(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
def _publish_graph_exec_update(self, res: GraphExecution):
def publish_graph_exec_update(self, res: GraphExecution):
event = GraphExecutionEvent.model_validate(res.model_dump())
self._publish(event, f"{res.user_id}/{res.graph_id}/{res.id}")
@@ -936,18 +936,17 @@ class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
def event_bus_name(self) -> str:
return config.execution_event_bus_name
@func_retry
async def publish(self, res: GraphExecutionMeta | NodeExecutionResult):
if isinstance(res, GraphExecutionMeta):
await self._publish_graph_exec_update(res)
await self.publish_graph_exec_update(res)
else:
await self._publish_node_exec_update(res)
await self.publish_node_exec_update(res)
async def _publish_node_exec_update(self, res: NodeExecutionResult):
async def publish_node_exec_update(self, res: NodeExecutionResult):
event = NodeExecutionEvent.model_validate(res.model_dump())
await self._publish(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
async def _publish_graph_exec_update(self, res: GraphExecutionMeta):
async def publish_graph_exec_update(self, res: GraphExecutionMeta):
# GraphExecutionEvent requires inputs and outputs fields that GraphExecutionMeta doesn't have
# Add default empty values for compatibility
event_data = res.model_dump()

View File

@@ -1,109 +0,0 @@
import logging
from collections import defaultdict
from datetime import datetime
from prisma.enums import AgentExecutionStatus
from backend.data.execution import get_graph_executions
from backend.data.graph import get_graph_metadata
from backend.data.model import UserExecutionSummaryStats
from backend.server.v2.store.exceptions import DatabaseError
from backend.util.logging import TruncatedLogger
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[SummaryData]")
async def _resolve_agent_name(graph_id: str, name_cache: dict[str, str]) -> str:
    """Return a display name for `graph_id`, looking up metadata at most once per cache.

    Falls back to ``"Agent <first 8 chars of id>"`` when the graph has no
    metadata or the lookup fails; the failure is logged, not raised.
    """
    if graph_id in name_cache:
        return name_cache[graph_id]

    fallback_name = f"Agent {graph_id[:8]}"
    try:
        graph_meta = await get_graph_metadata(graph_id=graph_id)
        agent_name = graph_meta.name if graph_meta else fallback_name
    except Exception:
        # Best-effort: a missing name must not fail the whole summary.
        logger.warning(f"Could not get metadata for graph {graph_id}")
        agent_name = fallback_name

    name_cache[graph_id] = agent_name
    return agent_name


async def get_user_execution_summary_data(
    user_id: str, start_time: datetime, end_time: datetime
) -> UserExecutionSummaryStats:
    """Gather all summary data for a user in a time range.

    Graph executions are fetched once and aggregated in a single pass.
    Graph names are resolved at most once per graph.

    Args:
        user_id: User whose executions are summarized.
        start_time: Lower bound passed as `created_time_gte`.
        end_time: Upper bound passed as `created_time_lte`.

    Returns:
        Aggregated stats; terminated runs are counted as failed runs.

    Raises:
        DatabaseError: If fetching or aggregating the data fails.
    """
    try:
        # Fetch graph executions once
        executions = await get_graph_executions(
            user_id=user_id,
            created_time_gte=start_time,
            created_time_lte=end_time,
        )

        # Initialize aggregation variables
        total_credits_used = 0.0
        total_executions = len(executions)
        successful_runs = 0
        failed_runs = 0
        terminated_runs = 0
        execution_times = []
        agent_usage = defaultdict(int)
        cost_by_graph_id = defaultdict(float)

        # Single pass through executions to aggregate all stats
        for execution in executions:
            # Count execution statuses (TERMINATED is folded into failed below)
            if execution.status == AgentExecutionStatus.COMPLETED:
                successful_runs += 1
            elif execution.status == AgentExecutionStatus.FAILED:
                failed_runs += 1
            elif execution.status == AgentExecutionStatus.TERMINATED:
                terminated_runs += 1

            # Aggregate costs from stats
            # NOTE(review): the /100 presumably converts cents to dollars — confirm.
            if execution.stats and hasattr(execution.stats, "cost"):
                cost_in_dollars = execution.stats.cost / 100
                total_credits_used += cost_in_dollars
                cost_by_graph_id[execution.graph_id] += cost_in_dollars

            # Collect execution times
            if execution.stats and hasattr(execution.stats, "duration"):
                execution_times.append(execution.stats.duration)

            # Count agent usage
            agent_usage[execution.graph_id] += 1

        # Calculate derived stats
        total_execution_time = sum(execution_times)
        average_execution_time = (
            total_execution_time / len(execution_times) if execution_times else 0
        )

        # Shared by both lookups below so each graph's metadata is fetched once.
        name_cache: dict[str, str] = {}

        # Find most used agent
        most_used_agent = "No agents used"
        if agent_usage:
            most_used_agent_id = max(agent_usage, key=lambda k: agent_usage[k])
            most_used_agent = await _resolve_agent_name(most_used_agent_id, name_cache)

        # Convert graph_ids to agent names for cost breakdown.
        # Distinct graphs can resolve to the same display name, so sum per name
        # instead of letting the last graph overwrite the others.
        cost_breakdown: dict[str, float] = {}
        for graph_id, cost in cost_by_graph_id.items():
            agent_name = await _resolve_agent_name(graph_id, name_cache)
            cost_breakdown[agent_name] = cost_breakdown.get(agent_name, 0.0) + cost

        # Build the summary stats object (include terminated runs as failed)
        return UserExecutionSummaryStats(
            total_credits_used=total_credits_used,
            total_executions=total_executions,
            successful_runs=successful_runs,
            failed_runs=failed_runs + terminated_runs,
            most_used_agent=most_used_agent,
            total_execution_time=total_execution_time,
            average_execution_time=average_execution_time,
            cost_breakdown=cost_breakdown,
        )

    except Exception as e:
        logger.error(f"Failed to get user summary data: {e}")
        raise DatabaseError(f"Failed to get user summary data: {e}") from e

View File

@@ -416,10 +416,6 @@ class GraphModel(Graph):
for_run: bool = False,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
):
"""
Validate graph structure and raise `ValueError` on issues.
For structured error reporting, use `validate_graph_get_errors`.
"""
self._validate_graph(self, for_run, nodes_input_masks)
for sub_graph in self.sub_graphs:
self._validate_graph(sub_graph, for_run, nodes_input_masks)
@@ -429,58 +425,15 @@ class GraphModel(Graph):
graph: BaseGraph,
for_run: bool = False,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> None:
errors = GraphModel._validate_graph_get_errors(
graph, for_run, nodes_input_masks
)
if errors:
# Just raise the first error for backward compatibility
first_error = next(iter(errors.values()))
first_field_error = next(iter(first_error.values()))
raise ValueError(first_field_error)
):
def is_tool_pin(name: str) -> bool:
return name.startswith("tools_^_")
def validate_graph_get_errors(
self,
for_run: bool = False,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> dict[str, dict[str, str]]:
"""
Validate graph and return structured errors per node.
Returns: dict[node_id, dict[field_name, error_message]]
"""
return {
**self._validate_graph_get_errors(self, for_run, nodes_input_masks),
**{
node_id: error
for sub_graph in self.sub_graphs
for node_id, error in self._validate_graph_get_errors(
sub_graph, for_run, nodes_input_masks
).items()
},
}
@staticmethod
def _validate_graph_get_errors(
graph: BaseGraph,
for_run: bool = False,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> dict[str, dict[str, str]]:
"""
Validate graph and return structured errors per node.
Returns: dict[node_id, dict[field_name, error_message]]
"""
# First, check for structural issues with the graph
try:
GraphModel._validate_graph_structure(graph)
except ValueError:
# If structural validation fails, we can't provide per-node errors
# so we re-raise as is
raise
# Collect errors per node
node_errors: dict[str, dict[str, str]] = defaultdict(dict)
def sanitize(name):
sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
if is_tool_pin(sanitized_name):
return "tools"
return sanitized_name
# Validate smart decision maker nodes
nodes_block = {
@@ -489,7 +442,7 @@ class GraphModel(Graph):
if (block := get_block(node.block_id)) is not None
}
input_links: dict[str, list[Link]] = defaultdict(list)
input_links = defaultdict(list)
for link in graph.links:
input_links[link.sink_id].append(link)
@@ -497,22 +450,17 @@ class GraphModel(Graph):
# Nodes: required fields are filled or connected and dependencies are satisfied
for node in graph.nodes:
if (block := nodes_block.get(node.id)) is None:
# For invalid blocks, we still raise immediately as this is a structural issue
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
node_input_mask = (
nodes_input_masks.get(node.id, {}) if nodes_input_masks else {}
)
provided_inputs = set(
[_sanitize_pin_name(name) for name in node.input_default]
+ [
_sanitize_pin_name(link.sink_name)
for link in input_links.get(node.id, [])
]
[sanitize(name) for name in node.input_default]
+ [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
+ ([name for name in node_input_mask] if node_input_mask else [])
)
InputSchema = block.input_schema
for name in (required_fields := InputSchema.get_required_fields()):
if (
name not in provided_inputs
@@ -529,16 +477,18 @@ class GraphModel(Graph):
]
)
):
node_errors[node.id][name] = "This field is required"
raise ValueError(
f"Node {block.name} #{node.id} required input missing: `{name}`"
)
if (
block.block_type == BlockType.INPUT
and (input_key := node.input_default.get("name"))
and is_credentials_field_name(input_key)
):
node_errors[node.id]["name"] = (
f"'{input_key}' is a reserved input name: "
"'credentials' and `*_credentials` are reserved"
raise ValueError(
f"Agent input node uses reserved name '{input_key}'; "
"'credentials' and `*_credentials` are reserved input names"
)
# Get input schema properties and check dependencies
@@ -588,15 +538,10 @@ class GraphModel(Graph):
# Check for missing dependencies when dependent field is present
missing_deps = [dep for dep in dependencies if not has_value(node, dep)]
if missing_deps and (field_has_value or field_is_required):
node_errors[node.id][
field_name
] = f"Requires {', '.join(missing_deps)} to be set"
raise ValueError(
f"Node {block.name} #{node.id}: Field `{field_name}` requires [{', '.join(missing_deps)}] to be set"
)
return node_errors
@staticmethod
def _validate_graph_structure(graph: BaseGraph):
"""Validate graph structure (links, connections, etc.)"""
node_map = {v.id: v for v in graph.nodes}
def is_static_output_block(nid: str) -> bool:
@@ -622,7 +567,7 @@ class GraphModel(Graph):
f"{prefix}, {node.block_id} is invalid block id, available blocks: {blocks}"
)
sanitized_name = _sanitize_pin_name(name)
sanitized_name = sanitize(name)
vals = node.input_default
if i == 0:
fields = (
@@ -636,7 +581,7 @@ class GraphModel(Graph):
if block.block_type not in [BlockType.AGENT]
else vals.get("input_schema", {}).get("properties", {}).keys()
)
if sanitized_name not in fields and not _is_tool_pin(name):
if sanitized_name not in fields and not is_tool_pin(name):
fields_msg = f"Allowed fields: {fields}"
raise ValueError(f"{prefix}, `{name}` invalid, {fields_msg}")
@@ -673,17 +618,6 @@ class GraphModel(Graph):
)
def _is_tool_pin(name: str) -> bool:
return name.startswith("tools_^_")
def _sanitize_pin_name(name: str) -> str:
    """Reduce a pin name to its base field name.

    Everything from the first `_#_`, `_@_` or `_$_` delimiter onward is
    dropped (applied in that order); a remainder that is a tool pin is
    collapsed to the single name "tools".
    """
    base_name = name
    for delimiter in ("_#_", "_@_", "_$_"):
        base_name = base_name.split(delimiter)[0]
    return "tools" if _is_tool_pin(base_name) else base_name
class GraphMeta(Graph):
user_id: str

View File

@@ -5,7 +5,6 @@ import enum
import logging
from collections import defaultdict
from datetime import datetime, timezone
from json import JSONDecodeError
from typing import (
TYPE_CHECKING,
Annotated,
@@ -41,120 +40,12 @@ from pydantic_core import (
from typing_extensions import TypedDict
from backend.integrations.providers import ProviderName
from backend.util.json import loads as json_loads
from backend.util.settings import Secrets
# Type alias for any provider name (including custom ones)
AnyProviderName = str # Will be validated as ProviderName at runtime
class User(BaseModel):
    """Application-layer User model with snake_case convention.

    Instances are normally built from a database row via `from_db`, which maps
    the camelCase database attributes onto the snake_case fields declared here.
    Unknown fields are rejected (`extra="forbid"`) and string values are
    stripped of surrounding whitespace.
    """

    model_config = ConfigDict(
        extra="forbid",
        str_strip_whitespace=True,
    )

    id: str = Field(..., description="User ID")
    email: str = Field(..., description="User email address")
    email_verified: bool = Field(default=True, description="Whether email is verified")
    name: Optional[str] = Field(None, description="User display name")
    created_at: datetime = Field(..., description="When user was created")
    updated_at: datetime = Field(..., description="When user was last updated")
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="User metadata as dict"
    )
    integrations: str = Field(default="", description="Encrypted integrations data")
    stripe_customer_id: Optional[str] = Field(None, description="Stripe customer ID")
    top_up_config: Optional["AutoTopUpConfig"] = Field(
        None, description="Top up configuration"
    )

    # Notification preferences
    max_emails_per_day: int = Field(default=3, description="Maximum emails per day")
    notify_on_agent_run: bool = Field(default=True, description="Notify on agent run")
    notify_on_zero_balance: bool = Field(
        default=True, description="Notify on zero balance"
    )
    notify_on_low_balance: bool = Field(
        default=True, description="Notify on low balance"
    )
    notify_on_block_execution_failed: bool = Field(
        default=True, description="Notify on block execution failure"
    )
    notify_on_continuous_agent_error: bool = Field(
        default=True, description="Notify on continuous agent error"
    )
    notify_on_daily_summary: bool = Field(
        default=True, description="Notify on daily summary"
    )
    notify_on_weekly_summary: bool = Field(
        default=True, description="Notify on weekly summary"
    )
    notify_on_monthly_summary: bool = Field(
        default=True, description="Notify on monthly summary"
    )

    @classmethod
    def from_db(cls, prisma_user: "PrismaUser") -> "User":
        """Convert a database User object to application User model.

        Args:
            prisma_user: The database row to convert.

        Returns:
            A `User` populated from the row. Metadata or top-up configuration
            that cannot be decoded/validated falls back to `{}` / `None`.
        """

        def default_true(value: Optional[bool]) -> bool:
            # `value or True` would always evaluate to True and therefore
            # discard an explicit False stored in the database; only a missing
            # (None) value may fall back to the default.
            return True if value is None else value

        # Handle metadata field - convert from JSON string or dict to dict
        metadata = {}
        if prisma_user.metadata:
            if isinstance(prisma_user.metadata, str):
                try:
                    metadata = json_loads(prisma_user.metadata)
                except (JSONDecodeError, TypeError):
                    metadata = {}
            elif isinstance(prisma_user.metadata, dict):
                metadata = prisma_user.metadata

        # Handle topUpConfig field
        top_up_config = None
        if prisma_user.topUpConfig:
            if isinstance(prisma_user.topUpConfig, str):
                try:
                    config_dict = json_loads(prisma_user.topUpConfig)
                    top_up_config = AutoTopUpConfig.model_validate(config_dict)
                except (JSONDecodeError, TypeError, ValueError):
                    top_up_config = None
            elif isinstance(prisma_user.topUpConfig, dict):
                try:
                    top_up_config = AutoTopUpConfig.model_validate(
                        prisma_user.topUpConfig
                    )
                except ValueError:
                    top_up_config = None

        return cls(
            id=prisma_user.id,
            email=prisma_user.email,
            email_verified=default_true(prisma_user.emailVerified),
            name=prisma_user.name,
            created_at=prisma_user.createdAt,
            updated_at=prisma_user.updatedAt,
            metadata=metadata,
            integrations=prisma_user.integrations or "",
            stripe_customer_id=prisma_user.stripeCustomerId,
            top_up_config=top_up_config,
            # NOTE(review): `or 3` also maps a stored 0 to 3; kept unchanged
            # because it is unclear whether 0 is a meaningful limit - confirm.
            max_emails_per_day=prisma_user.maxEmailsPerDay or 3,
            notify_on_agent_run=default_true(prisma_user.notifyOnAgentRun),
            notify_on_zero_balance=default_true(prisma_user.notifyOnZeroBalance),
            notify_on_low_balance=default_true(prisma_user.notifyOnLowBalance),
            notify_on_block_execution_failed=default_true(
                prisma_user.notifyOnBlockExecutionFailed
            ),
            notify_on_continuous_agent_error=default_true(
                prisma_user.notifyOnContinuousAgentError
            ),
            notify_on_daily_summary=default_true(prisma_user.notifyOnDailySummary),
            notify_on_weekly_summary=default_true(prisma_user.notifyOnWeeklySummary),
            notify_on_monthly_summary=default_true(
                prisma_user.notifyOnMonthlySummary
            ),
        )
if TYPE_CHECKING:
from prisma.models import User as PrismaUser
from backend.data.block import BlockSchema
T = TypeVar("T")
@@ -753,7 +644,7 @@ class NodeExecutionStats(BaseModel):
arbitrary_types_allowed=True,
)
error: Optional[BaseException | str] = None
error: Optional[Exception | str] = None
walltime: float = 0
cputime: float = 0
input_size: int = 0
@@ -764,9 +655,6 @@ class NodeExecutionStats(BaseModel):
output_token_count: int = 0
extra_cost: int = 0
extra_steps: int = 0
# Moderation fields
cleared_inputs: Optional[dict[str, list[str]]] = None
cleared_outputs: Optional[dict[str, list[str]]] = None
def __iadd__(self, other: "NodeExecutionStats") -> "NodeExecutionStats":
"""Mutate this instance by adding another NodeExecutionStats."""
@@ -818,24 +706,3 @@ class GraphExecutionStats(BaseModel):
default=0, description="Total number of errors generated"
)
cost: int = Field(default=0, description="Total execution cost (cents)")
activity_status: Optional[str] = Field(
default=None, description="AI-generated summary of what the agent did"
)
class UserExecutionSummaryStats(BaseModel):
    """Aggregated execution statistics for a single user.

    Every field has a zero/empty default; additional, undeclared fields are
    accepted (`extra="allow"`).
    """

    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

    # Totals
    total_credits_used: float = 0
    total_executions: int = 0
    successful_runs: int = 0
    failed_runs: int = 0
    most_used_agent: str = ""

    # Timing
    total_execution_time: float = 0
    average_execution_time: float = 0

    # Mapping of string keys to float costs
    cost_breakdown: dict[str, float] = Field(default_factory=dict)

View File

@@ -4,12 +4,20 @@ from enum import Enum
from typing import Awaitable, Optional
import aio_pika
import aio_pika.exceptions as aio_ex
import pika
import pika.adapters.blocking_connection
from pika.exceptions import AMQPError
from pika.spec import BasicProperties
from pydantic import BaseModel
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from backend.util.retry import conn_retry, func_retry
from backend.util.retry import conn_retry
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -140,7 +148,6 @@ class SyncRabbitMQ(RabbitMQBase):
socket_timeout=SOCKET_TIMEOUT,
connection_attempts=CONNECTION_ATTEMPTS,
retry_delay=RETRY_DELAY,
heartbeat=300, # 5 minute timeout (heartbeats sent every 2.5 min)
)
self._connection = pika.BlockingConnection(parameters)
@@ -191,7 +198,12 @@ class SyncRabbitMQ(RabbitMQBase):
routing_key=queue.routing_key or queue.name,
)
@func_retry
@retry(
retry=retry_if_exception_type((AMQPError, ConnectionError)),
wait=wait_random_exponential(multiplier=1, max=5),
stop=stop_after_attempt(5),
reraise=True,
)
def publish_message(
self,
routing_key: str,
@@ -245,7 +257,6 @@ class AsyncRabbitMQ(RabbitMQBase):
password=self.password,
virtualhost=self.config.vhost.lstrip("/"),
blocked_connection_timeout=BLOCKED_CONNECTION_TIMEOUT,
heartbeat=300, # 5 minute timeout (heartbeats sent every 2.5 min)
)
self._channel = await self._connection.channel()
await self._channel.set_qos(prefetch_count=1)
@@ -291,7 +302,12 @@ class AsyncRabbitMQ(RabbitMQBase):
exchange, routing_key=queue.routing_key or queue.name
)
@func_retry
@retry(
retry=retry_if_exception_type((aio_ex.AMQPError, ConnectionError)),
wait=wait_random_exponential(multiplier=1, max=5),
stop=stop_after_attempt(5),
reraise=True,
)
async def publish_message(
self,
routing_key: str,

View File

@@ -9,11 +9,11 @@ from urllib.parse import quote_plus
from autogpt_libs.auth.models import DEFAULT_USER_ID
from fastapi import HTTPException
from prisma.enums import NotificationType
from prisma.models import User as PrismaUser
from prisma.models import User
from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
from backend.data.db import prisma
from backend.data.model import User, UserIntegrations, UserMetadata
from backend.data.model import UserIntegrations, UserMetadata, UserMetadataRaw
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.server.v2.store.exceptions import DatabaseError
from backend.util.encryption import JSONCryptor
@@ -21,7 +21,6 @@ from backend.util.json import SafeJson
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
async def get_or_create_user(user_data: dict) -> User:
@@ -44,7 +43,7 @@ async def get_or_create_user(user_data: dict) -> User:
)
)
return User.from_db(user)
return User.model_validate(user)
except Exception as e:
raise DatabaseError(f"Failed to get or create user {user_data}: {e}") from e
@@ -53,7 +52,7 @@ async def get_user_by_id(user_id: str) -> User:
user = await prisma.user.find_unique(where={"id": user_id})
if not user:
raise ValueError(f"User not found with ID: {user_id}")
return User.from_db(user)
return User.model_validate(user)
async def get_user_email_by_id(user_id: str) -> Optional[str]:
@@ -67,7 +66,7 @@ async def get_user_email_by_id(user_id: str) -> Optional[str]:
async def get_user_by_email(email: str) -> Optional[User]:
try:
user = await prisma.user.find_unique(where={"email": email})
return User.from_db(user) if user else None
return User.model_validate(user) if user else None
except Exception as e:
raise DatabaseError(f"Failed to get user by email {email}: {e}") from e
@@ -91,11 +90,27 @@ async def create_default_user() -> Optional[User]:
name="Default User",
)
)
return User.from_db(user)
return User.model_validate(user)
async def get_user_metadata(user_id: str) -> UserMetadata:
    """Load the user with the given id and return its validated metadata.

    The lookup uses `find_unique_or_raise`, so a missing user surfaces as an
    exception from the database client rather than a `None` result.
    """
    db_user = await User.prisma().find_unique_or_raise(where={"id": user_id})
    raw_metadata = cast(UserMetadataRaw, db_user.metadata)
    return UserMetadata.model_validate(raw_metadata)
async def update_user_metadata(user_id: str, metadata: UserMetadata):
    """Overwrite the stored metadata of the user with the given id.

    The model is dumped to a plain dict and wrapped in `SafeJson` before being
    written to the `metadata` column.
    """
    user_table = User.prisma()
    new_values = {"metadata": SafeJson(metadata.model_dump())}
    await user_table.update(where={"id": user_id}, data=new_values)
async def get_user_integrations(user_id: str) -> UserIntegrations:
user = await PrismaUser.prisma().find_unique_or_raise(
user = await User.prisma().find_unique_or_raise(
where={"id": user_id},
)
@@ -110,7 +125,7 @@ async def get_user_integrations(user_id: str) -> UserIntegrations:
async def update_user_integrations(user_id: str, data: UserIntegrations):
encrypted_data = JSONCryptor().encrypt(data.model_dump(exclude_none=True))
await PrismaUser.prisma().update(
await User.prisma().update(
where={"id": user_id},
data={"integrations": encrypted_data},
)
@@ -118,7 +133,7 @@ async def update_user_integrations(user_id: str, data: UserIntegrations):
async def migrate_and_encrypt_user_integrations():
"""Migrate integration credentials and OAuth states from metadata to integrations column."""
users = await PrismaUser.prisma().find_many(
users = await User.prisma().find_many(
where={
"metadata": cast(
JsonFilter,
@@ -154,7 +169,7 @@ async def migrate_and_encrypt_user_integrations():
raw_metadata.pop("integration_oauth_states", None)
# Update metadata without integration data
await PrismaUser.prisma().update(
await User.prisma().update(
where={"id": user.id},
data={"metadata": SafeJson(raw_metadata)},
)
@@ -162,7 +177,7 @@ async def migrate_and_encrypt_user_integrations():
async def get_active_user_ids_in_timerange(start_time: str, end_time: str) -> list[str]:
try:
users = await PrismaUser.prisma().find_many(
users = await User.prisma().find_many(
where={
"AgentGraphExecutions": {
"some": {
@@ -192,7 +207,7 @@ async def get_active_users_ids() -> list[str]:
async def get_user_notification_preference(user_id: str) -> NotificationPreference:
try:
user = await PrismaUser.prisma().find_unique_or_raise(
user = await User.prisma().find_unique_or_raise(
where={"id": user_id},
)
@@ -269,7 +284,7 @@ async def update_user_notification_preference(
if data.daily_limit:
update_data["maxEmailsPerDay"] = data.daily_limit
user = await PrismaUser.prisma().update(
user = await User.prisma().update(
where={"id": user_id},
data=update_data,
)
@@ -307,7 +322,7 @@ async def update_user_notification_preference(
async def set_user_email_verification(user_id: str, verified: bool) -> None:
"""Set the email verification status for a user."""
try:
await PrismaUser.prisma().update(
await User.prisma().update(
where={"id": user_id},
data={"emailVerified": verified},
)
@@ -320,7 +335,7 @@ async def set_user_email_verification(user_id: str, verified: bool) -> None:
async def get_user_email_verification(user_id: str) -> bool:
"""Get the email verification status for a user."""
try:
user = await PrismaUser.prisma().find_unique_or_raise(
user = await User.prisma().find_unique_or_raise(
where={"id": user_id},
)
return user.emailVerified
@@ -333,7 +348,7 @@ async def get_user_email_verification(user_id: str) -> bool:
def generate_unsubscribe_link(user_id: str) -> str:
"""Generate a link to unsubscribe from all notifications"""
# Create an HMAC using a secret key
secret_key = settings.secrets.unsubscribe_secret_key
secret_key = Settings().secrets.unsubscribe_secret_key
signature = hmac.new(
secret_key.encode("utf-8"), user_id.encode("utf-8"), hashlib.sha256
).digest()
@@ -344,7 +359,7 @@ def generate_unsubscribe_link(user_id: str) -> str:
).decode("utf-8")
logger.info(f"Generating unsubscribe link for user {user_id}")
base_url = settings.config.platform_base_url
base_url = Settings().config.platform_base_url
return f"{base_url}/api/email/unsubscribe?token={quote_plus(token)}"
@@ -356,7 +371,7 @@ async def unsubscribe_user_by_token(token: str) -> None:
user_id, received_signature_hex = decoded.split(":", 1)
# Verify the signature
secret_key = settings.secrets.unsubscribe_secret_key
secret_key = Settings().secrets.unsubscribe_secret_key
expected_signature = hmac.new(
secret_key.encode("utf-8"), user_id.encode("utf-8"), hashlib.sha256
).digest()

View File

@@ -1,434 +0,0 @@
"""
Module for generating AI-based activity status for graph executions.
"""
import json
import logging
from typing import TYPE_CHECKING, Any, NotRequired, TypedDict
from pydantic import SecretStr
from backend.blocks.llm import LlmModel, llm_call
from backend.data.block import get_block
from backend.data.execution import ExecutionStatus, NodeExecutionResult
from backend.data.model import APIKeyCredentials, GraphExecutionStats
from backend.util.feature_flag import Flag, is_feature_enabled
from backend.util.retry import func_retry
from backend.util.settings import Settings
from backend.util.truncate import truncate
if TYPE_CHECKING:
from backend.executor import DatabaseManagerAsyncClient
logger = logging.getLogger(__name__)
class ErrorInfo(TypedDict):
    """One recorded error of a node execution."""

    # Human-readable error message
    error: str
    # Identifier of the node execution that produced the error
    execution_id: str
    # Time string associated with the execution
    timestamp: str
class InputOutputInfo(TypedDict):
    """One sampled input or output payload of a node execution."""

    # Identifier of the node execution the sample belongs to
    execution_id: str
    # Payload of the sample; the same key is used for inputs and outputs
    output_data: dict[str, Any]
    # Time string associated with the execution
    timestamp: str
class NodeInfo(TypedDict):
    """Per-node entry of the execution summary."""

    # Identity of the node and of the block it runs
    node_id: str
    block_id: str
    block_name: str
    block_description: str

    # Aggregated counters over all executions of this node
    execution_count: int
    error_count: int

    # Limited samples of what happened on this node
    recent_errors: list[ErrorInfo]
    recent_outputs: list[InputOutputInfo]
    recent_inputs: list[InputOutputInfo]
class NodeRelation(TypedDict):
"""Type definition for node relation information."""
source_node_id: str
sink_node_id: str
source_name: str
sink_name: str
is_static: bool
source_block_name: NotRequired[str] # Optional, only set if block exists
sink_block_name: NotRequired[str] # Optional, only set if block exists
def _truncate_uuid(uuid_str: str) -> str:
"""Truncate UUID to first segment to reduce payload size."""
if not uuid_str:
return uuid_str
return uuid_str.split("-")[0] if "-" in uuid_str else uuid_str[:8]
def _activity_status_prompt(
    graph_name: str, execution_data: dict[str, Any]
) -> list[dict[str, str]]:
    """Build the system + user chat messages sent to the LLM."""
    system_message = (
        "You are an AI assistant summarizing what you just did for a user in simple, friendly language. "
        "Write from the user's perspective about what they accomplished, NOT about technical execution details. "
        "Focus on the ACTUAL TASK the user wanted done, not the internal workflow steps. "
        "Avoid technical terms like 'workflow', 'execution', 'components', 'nodes', 'processing', etc. "
        "Keep it to 3 sentences maximum. Be conversational and human-friendly.\n\n"
        "IMPORTANT: Be HONEST about what actually happened:\n"
        "- If the input was invalid/nonsensical, say so directly\n"
        "- If the task failed, explain what went wrong in simple terms\n"
        "- If errors occurred, focus on what the user needs to know\n"
        "- Only claim success if the task was genuinely completed\n"
        "- Don't sugar-coat failures or present them as helpful feedback\n\n"
        "Understanding Errors:\n"
        "- Node errors: Individual steps may fail but the overall task might still complete (e.g., one data source fails but others work)\n"
        "- Graph error (in overall_status.graph_error): This means the entire execution failed and nothing was accomplished\n"
        "- Even if execution shows 'completed', check if critical nodes failed that would prevent the desired outcome\n"
        "- Focus on the end result the user wanted, not whether technical steps completed"
    )
    user_message = (
        f"A user ran '{graph_name}' to accomplish something. Based on this execution data, "
        f"write what they achieved in simple, user-friendly terms:\n\n"
        f"{json.dumps(execution_data, indent=2)}\n\n"
        "CRITICAL: Check overall_status.graph_error FIRST - if present, the entire execution failed.\n"
        "Then check individual node errors to understand partial failures.\n\n"
        "Write 1-3 sentences about what the user accomplished, such as:\n"
        "- 'I analyzed your resume and provided detailed feedback for the IT industry.'\n"
        "- 'I couldn't analyze your resume because the input was just nonsensical text.'\n"
        "- 'I failed to complete the task due to missing API access.'\n"
        "- 'I extracted key information from your documents and organized it into a summary.'\n"
        "- 'The task failed to run due to system configuration issues.'\n\n"
        "Focus on what ACTUALLY happened, not what was attempted."
    )
    return [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]


async def generate_activity_status_for_execution(
    graph_exec_id: str,
    graph_id: str,
    graph_version: int,
    execution_stats: GraphExecutionStats,
    db_client: "DatabaseManagerAsyncClient",
    user_id: str,
    execution_status: ExecutionStatus | None = None,
) -> str | None:
    """
    Generate an AI-based activity status summary for a graph execution.

    Collects the node executions and graph information through `db_client`,
    condenses them into a summary, and asks the LLM for a short description.

    Args:
        graph_exec_id: The graph execution ID
        graph_id: The graph ID
        graph_version: The graph version
        execution_stats: Execution statistics
        db_client: Database client for fetching data
        user_id: User ID for LaunchDarkly feature flag evaluation
        execution_status: The overall execution status (COMPLETED, FAILED, TERMINATED)

    Returns:
        The generated activity status string. None when the feature flag is
        off, when no OpenAI API key is configured, or when anything fails
        while collecting data or calling the LLM (the failure is logged).
    """
    # Feature flag gate (LaunchDarkly); evaluated outside the try block below.
    if not await is_feature_enabled(Flag.AI_ACTIVITY_STATUS, user_id):
        logger.debug("AI activity status generation is disabled via LaunchDarkly")
        return None

    try:
        # Without a system OpenAI key there is nothing to call.
        openai_api_key = Settings().secrets.openai_api_key
        if not openai_api_key:
            logger.debug(
                "OpenAI API key not configured, skipping activity status generation"
            )
            return None

        # All node executions of this run, including their input/output data.
        node_executions = await db_client.get_node_executions(
            graph_exec_id, include_exec_data=True
        )

        # Graph metadata provides name/description, the full graph its links.
        graph_metadata = await db_client.get_graph_metadata(graph_id, graph_version)
        graph = await db_client.get_graph(graph_id, graph_version)

        if graph_metadata:
            graph_name = graph_metadata.name
            graph_description = graph_metadata.description
        else:
            graph_name = f"Graph {graph_id}"
            graph_description = ""

        execution_data = _build_execution_summary(
            node_executions,
            execution_stats,
            graph_name,
            graph_description,
            graph.links if graph else [],
            execution_status,
        )

        prompt = _activity_status_prompt(graph_name, execution_data)

        # Log the prompt for debugging purposes
        logger.debug(
            f"Sending prompt to LLM for graph execution {graph_exec_id}: {json.dumps(prompt, indent=2)}"
        )

        # System-level credentials for the LLM call
        credentials = APIKeyCredentials(
            id="openai",
            provider="openai",
            api_key=SecretStr(openai_api_key),
            title="System OpenAI",
        )

        activity_status = await _call_llm_direct(credentials, prompt)
        logger.debug(
            f"Generated activity status for {graph_exec_id}: {activity_status}"
        )
        return activity_status

    except Exception as e:
        # Best effort: a failed summary must not propagate to the caller.
        logger.error(
            f"Failed to generate activity status for execution {graph_exec_id}: {str(e)}"
        )
        return None
def _extract_error_message(output_data: Any) -> str:
    """Return the error text stored under `"error"` in a failed node's output."""
    if output_data and isinstance(output_data, dict) and "error" in output_data:
        error_output = output_data["error"]
        if isinstance(error_output, list) and error_output:
            # A list of errors is represented by its first entry.
            return str(error_output[0])
        return str(error_output)
    return "Unknown error"


def _make_io_sample(node_exec: NodeExecutionResult, data: Any) -> InputOutputInfo:
    """Build one input/output sample with the payload truncated to save space."""
    return {
        "execution_id": _truncate_uuid(node_exec.node_exec_id),
        # The same field name is used for input and output samples.
        "output_data": truncate(data, 100),
        "timestamp": node_exec.add_time.isoformat(),
    }


def _latest_samples(
    samples: list[InputOutputInfo], limit: int = 3
) -> list[InputOutputInfo]:
    """Return the `limit` samples with the greatest timestamp strings, newest first."""
    return sorted(samples, key=lambda s: s["timestamp"], reverse=True)[:limit]


def _build_execution_summary(
    node_executions: list[NodeExecutionResult],
    execution_stats: GraphExecutionStats,
    graph_name: str,
    graph_description: str,
    graph_links: list[Any],
    execution_status: ExecutionStatus | None = None,
) -> dict[str, Any]:
    """Build a structured summary of execution data for AI analysis.

    Args:
        node_executions: Node executions of the graph run. Executions whose
            block cannot be resolved via `get_block` are logged and ignored.
        execution_stats: Graph-level statistics reported under `overall_status`.
        graph_name: Name reported under `graph_info`.
        graph_description: Description reported under `graph_info`.
        graph_links: Graph links; each must expose `source_id`, `sink_id`,
            `source_name` and `sink_name` (`is_static` is optional).
        execution_status: Overall status of the run, if known.

    Returns:
        A dict with the keys `graph_info`, `nodes`, `node_relations`,
        `input_output_data` and `overall_status`. Node and execution ids
        inside `nodes` and `node_relations` are truncated; the keys of
        `input_output_data` use the full node id.
    """
    # Keyed by the full node id. Dicts preserve insertion order, so iterating
    # this map yields the nodes in first-seen order together with their
    # original id - no reverse lookup from the truncated id is needed.
    node_map: dict[str, NodeInfo] = {}
    node_errors: dict[str, list[ErrorInfo]] = {}
    node_outputs: dict[str, list[InputOutputInfo]] = {}
    node_inputs: dict[str, list[InputOutputInfo]] = {}
    input_output_data: dict[str, Any] = {}

    # Process node executions
    for node_exec in node_executions:
        block = get_block(node_exec.block_id)
        if not block:
            logger.warning(
                f"Block {node_exec.block_id} not found for node {node_exec.node_id}"
            )
            continue

        node_id = node_exec.node_id

        # Create the node entry the first time the node is seen.
        if node_id not in node_map:
            node_map[node_id] = {
                "node_id": _truncate_uuid(node_id),
                "block_id": _truncate_uuid(node_exec.block_id),
                "block_name": block.name,
                "block_description": block.description or "",
                "execution_count": 0,
                "error_count": 0,
                "recent_errors": [],  # Filled in after the loop
                "recent_outputs": [],  # Filled in after the loop
                "recent_inputs": [],  # Filled in after the loop
            }
        node = node_map[node_id]
        node["execution_count"] += 1

        # Track errors per node and group them
        if node_exec.status == ExecutionStatus.FAILED:
            node["error_count"] += 1
            node_errors.setdefault(node_id, []).append(
                {
                    "error": _extract_error_message(node_exec.output_data),
                    "execution_id": _truncate_uuid(node_exec.node_exec_id),
                    "timestamp": node_exec.add_time.isoformat(),
                }
            )

        # Collect output and input samples for each node
        if node_exec.output_data:
            node_outputs.setdefault(node_id, []).append(
                _make_io_sample(node_exec, node_exec.output_data)
            )
        if node_exec.input_data:
            node_inputs.setdefault(node_id, []).append(
                _make_io_sample(node_exec, node_exec.input_data)
            )

        # Store input/output data for special blocks (input/output blocks)
        if block.name in ["AgentInputBlock", "AgentOutputBlock", "UserInputBlock"]:
            if node_exec.input_data:
                input_output_data[f"{node_id}_inputs"] = dict(node_exec.input_data)
            if node_exec.output_data:
                input_output_data[f"{node_id}_outputs"] = dict(node_exec.output_data)

    # Attach limited error lists and the latest samples to every node
    for node_id, node in node_map.items():
        errors = node_errors.get(node_id, [])
        # Keep everything up to 10 errors, otherwise the first 5 + last 5.
        node["recent_errors"] = (
            errors if len(errors) <= 10 else errors[:5] + errors[-5:]
        )
        node["recent_outputs"] = _latest_samples(node_outputs.get(node_id, []))
        node["recent_inputs"] = _latest_samples(node_inputs.get(node_id, []))

    nodes: list[NodeInfo] = list(node_map.values())

    # Build node relations from graph links
    node_relations: list[NodeRelation] = []
    for link in graph_links:
        relation: NodeRelation = {
            "source_node_id": _truncate_uuid(link.source_id),
            "sink_node_id": _truncate_uuid(link.sink_id),
            "source_name": link.source_name,
            "sink_name": link.sink_name,
            "is_static": getattr(link, "is_static", False),
        }
        # Add block names only for nodes that actually executed
        if link.source_id in node_map:
            relation["source_block_name"] = node_map[link.source_id]["block_name"]
        if link.sink_id in node_map:
            relation["sink_block_name"] = node_map[link.sink_id]["block_name"]
        node_relations.append(relation)

    # Build overall summary
    return {
        "graph_info": {"name": graph_name, "description": graph_description},
        "nodes": nodes,
        "node_relations": node_relations,
        "input_output_data": input_output_data,
        "overall_status": {
            # Counts only nodes that appear in `node_executions` with a known block
            "total_nodes_in_graph": len(nodes),
            "total_executions": execution_stats.node_count,
            "total_errors": execution_stats.node_error_count,
            "execution_time_seconds": execution_stats.walltime,
            "has_errors": bool(
                execution_stats.error or execution_stats.node_error_count > 0
            ),
            "graph_error": (
                str(execution_stats.error) if execution_stats.error else None
            ),
            "graph_execution_status": (
                execution_status.value if execution_status else None
            ),
        },
    }
@func_retry
async def _call_llm_direct(
    credentials: APIKeyCredentials, prompt: list[dict[str, str]]
) -> str:
    """Send `prompt` to the LLM and return its stripped text reply.

    Falls back to a fixed placeholder sentence when the call yields no
    response or an empty response text.
    """
    llm_response = await llm_call(
        credentials=credentials,
        llm_model=LlmModel.GPT4O_MINI,
        prompt=prompt,
        json_format=False,
        max_tokens=150,
        compress_prompt_to_fit=True,
    )
    if not (llm_response and llm_response.response):
        return "Unable to generate activity summary"
    return llm_response.response.strip()

View File

@@ -1,702 +0,0 @@
"""
Tests for activity status generator functionality.
"""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks.llm import LLMResponse
from backend.data.execution import ExecutionStatus, NodeExecutionResult
from backend.data.model import GraphExecutionStats
from backend.executor.activity_status_generator import (
_build_execution_summary,
_call_llm_direct,
generate_activity_status_for_execution,
)
@pytest.fixture
def mock_node_executions():
    """Provide three node executions: two completed ones followed by a failed one."""

    def make_execution(
        *, node_exec_id, node_id, block_id, status, input_data, output_data
    ):
        # Everything not passed in is identical for all three executions.
        return NodeExecutionResult(
            user_id="test_user",
            graph_id="test_graph",
            graph_version=1,
            graph_exec_id="test_exec",
            node_exec_id=node_exec_id,
            node_id=node_id,
            block_id=block_id,
            status=status,
            input_data=input_data,
            output_data=output_data,
            add_time=datetime.now(timezone.utc),
            queue_time=None,
            start_time=None,
            end_time=None,
        )

    completed_input = make_execution(
        node_exec_id="123e4567-e89b-12d3-a456-426614174001",
        node_id="456e7890-e89b-12d3-a456-426614174002",
        block_id="789e1234-e89b-12d3-a456-426614174003",
        status=ExecutionStatus.COMPLETED,
        input_data={"user_input": "Hello, world!"},
        output_data={"processed_input": ["Hello, world!"]},
    )
    completed_processing = make_execution(
        node_exec_id="234e5678-e89b-12d3-a456-426614174004",
        node_id="567e8901-e89b-12d3-a456-426614174005",
        block_id="890e2345-e89b-12d3-a456-426614174006",
        status=ExecutionStatus.COMPLETED,
        input_data={"data": "Hello, world!"},
        output_data={"result": ["Processed data"]},
    )
    failed_output = make_execution(
        node_exec_id="345e6789-e89b-12d3-a456-426614174007",
        node_id="678e9012-e89b-12d3-a456-426614174008",
        block_id="901e3456-e89b-12d3-a456-426614174009",
        status=ExecutionStatus.FAILED,
        input_data={"final_data": "Processed data"},
        output_data={
            "error": ["Connection timeout: Unable to reach external service"]
        },
    )
    return [completed_input, completed_processing, failed_output]
@pytest.fixture
def mock_execution_stats():
    """Provide graph stats with one node error and no graph-level error."""
    stats_fields = {
        "walltime": 2.5,
        "cputime": 1.8,
        "nodes_walltime": 2.0,
        "nodes_cputime": 1.5,
        "node_count": 3,
        "node_error_count": 1,
        "cost": 10,
        "error": None,
    }
    return GraphExecutionStats(**stats_fields)
@pytest.fixture
def mock_execution_stats_with_graph_error():
    """Provide graph stats whose `error` field carries a graph-level failure."""
    stats_fields = {
        "walltime": 2.5,
        "cputime": 1.8,
        "nodes_walltime": 2.0,
        "nodes_cputime": 1.5,
        "node_count": 3,
        "node_error_count": 1,
        "cost": 10,
        "error": "Graph execution failed: Invalid API credentials",
    }
    return GraphExecutionStats(**stats_fields)
@pytest.fixture
def mock_blocks():
    """Provide a block-id -> mock block mapping for the executions under test."""

    def make_block(name, description):
        # `name` must be assigned after construction: MagicMock(name=...) would
        # name the mock itself instead of setting a `name` attribute.
        block = MagicMock()
        block.name = name
        block.description = description
        return block

    input_block = make_block("AgentInputBlock", "Handles user input")
    process_block = make_block("ProcessingBlock", "Processes data")
    output_block = make_block("AgentOutputBlock", "Provides output to user")

    blocks_by_id = {
        "789e1234-e89b-12d3-a456-426614174003": input_block,
        "890e2345-e89b-12d3-a456-426614174006": process_block,
        "901e3456-e89b-12d3-a456-426614174009": output_block,
    }
    # Keep old key for different error format test
    blocks_by_id["process_block_id"] = process_block
    return blocks_by_id
class TestBuildExecutionSummary:
"""Tests for _build_execution_summary function."""
def test_build_summary_with_successful_execution(
self, mock_node_executions, mock_execution_stats, mock_blocks
):
"""Test building summary for successful execution."""
# Create mock links with realistic UUIDs
mock_links = [
MagicMock(
source_id="456e7890-e89b-12d3-a456-426614174002",
sink_id="567e8901-e89b-12d3-a456-426614174005",
source_name="output",
sink_name="input",
is_static=False,
)
]
with patch(
"backend.executor.activity_status_generator.get_block"
) as mock_get_block:
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
summary = _build_execution_summary(
mock_node_executions[:2],
mock_execution_stats,
"Test Graph",
"A test graph for processing",
mock_links,
ExecutionStatus.COMPLETED,
)
# Check graph info
assert summary["graph_info"]["name"] == "Test Graph"
assert summary["graph_info"]["description"] == "A test graph for processing"
# Check nodes with per-node counts
assert len(summary["nodes"]) == 2
assert summary["nodes"][0]["block_name"] == "AgentInputBlock"
assert summary["nodes"][0]["execution_count"] == 1
assert summary["nodes"][0]["error_count"] == 0
assert summary["nodes"][1]["block_name"] == "ProcessingBlock"
assert summary["nodes"][1]["execution_count"] == 1
assert summary["nodes"][1]["error_count"] == 0
# Check node relations (UUIDs are truncated to first segment)
assert len(summary["node_relations"]) == 1
assert (
summary["node_relations"][0]["source_node_id"] == "456e7890"
) # Truncated
assert (
summary["node_relations"][0]["sink_node_id"] == "567e8901"
) # Truncated
assert (
summary["node_relations"][0]["source_block_name"] == "AgentInputBlock"
)
assert summary["node_relations"][0]["sink_block_name"] == "ProcessingBlock"
# Check overall status
assert summary["overall_status"]["total_nodes_in_graph"] == 2
assert summary["overall_status"]["total_executions"] == 3
assert summary["overall_status"]["total_errors"] == 1
assert summary["overall_status"]["execution_time_seconds"] == 2.5
assert summary["overall_status"]["graph_execution_status"] == "COMPLETED"
# Check input/output data (using actual node UUIDs)
assert (
"456e7890-e89b-12d3-a456-426614174002_inputs"
in summary["input_output_data"]
)
assert (
"456e7890-e89b-12d3-a456-426614174002_outputs"
in summary["input_output_data"]
)
def test_build_summary_with_failed_execution(
    self, mock_node_executions, mock_execution_stats, mock_blocks
):
    """Check that a failed node's error is reported on the node itself.

    Builds a summary for a FAILED run without links and inspects the node
    whose (truncated) id is ``678e9012``: it must report one execution, one
    error, and a single ``recent_errors`` entry carrying the error message
    and an execution id.
    """
    no_links = []  # this scenario does not need node relations
    with patch(
        "backend.executor.activity_status_generator.get_block"
    ) as mock_get_block:
        # Resolve blocks through the mock_blocks fixture.
        mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
        summary = _build_execution_summary(
            mock_node_executions,
            mock_execution_stats,
            "Failed Graph",
            "Test with failures",
            no_links,
            ExecutionStatus.FAILED,
        )

    # The summary exposes truncated node ids, so match on the short form.
    failed_node = next(
        node for node in summary["nodes"] if node["node_id"] == "678e9012"
    )
    assert failed_node["error_count"] == 1
    assert failed_node["execution_count"] == 1

    # Errors are reported in the node's recent_errors field.
    assert "recent_errors" in failed_node
    recent_errors = failed_node["recent_errors"]
    assert len(recent_errors) == 1
    first_error = recent_errors[0]
    assert (
        first_error["error"]
        == "Connection timeout: Unable to reach external service"
    )
    assert "execution_id" in first_error  # entry should include execution ID
def test_build_summary_with_missing_blocks(
    self, mock_node_executions, mock_execution_stats
):
    """Check that the summary is still built when no block can be resolved.

    ``get_block`` is patched to return ``None`` for every lookup; the
    resulting summary is expected to contain no nodes while still carrying
    the graph name.
    """
    with patch(
        "backend.executor.activity_status_generator.get_block",
        return_value=None,
    ):
        summary = _build_execution_summary(
            mock_node_executions,
            mock_execution_stats,
            "Missing Blocks Graph",
            "Test with missing blocks",
            [],
            ExecutionStatus.COMPLETED,
        )

    # Missing blocks are handled gracefully: no node entries are produced.
    assert len(summary["nodes"]) == 0
    # There is no top-level errors field; only the graph info is checked here.
    graph_info = summary["graph_info"]
    assert graph_info["name"] == "Missing Blocks Graph"
def test_build_summary_with_graph_error(
    self, mock_node_executions, mock_execution_stats_with_graph_error, mock_blocks
):
    """Check that a graph-level error is surfaced in ``overall_status``.

    Uses the ``mock_execution_stats_with_graph_error`` fixture and expects
    the error flag, the graph error message, the error total and the FAILED
    status to appear in the summary's overall status section.
    """
    with patch(
        "backend.executor.activity_status_generator.get_block"
    ) as mock_get_block:
        mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
        summary = _build_execution_summary(
            mock_node_executions,
            mock_execution_stats_with_graph_error,
            "Graph with Error",
            "Test with graph error",
            [],  # no links in this scenario
            ExecutionStatus.FAILED,
        )

    # The graph error must be part of the overall status.
    overall = summary["overall_status"]
    assert overall["has_errors"] is True
    assert overall["graph_error"] == (
        "Graph execution failed: Invalid API credentials"
    )
    assert overall["total_errors"] == 1
    assert overall["graph_execution_status"] == "FAILED"
def test_build_summary_with_different_error_formats(
    self, mock_execution_stats, mock_blocks
):
    """Check how node errors are reported for differing output payloads.

    Two FAILED executions are summarised: one whose output holds an
    ``error`` list with a message, and one with empty output. The first is
    expected to report that message, the second ``"Unknown error"``.
    """

    def failed_execution(node_exec_id, node_id, output_data):
        # The two executions differ only in their ids and output payload.
        return NodeExecutionResult(
            user_id="test_user",
            graph_id="test_graph",
            graph_version=1,
            graph_exec_id="test_exec",
            node_exec_id=node_exec_id,
            node_id=node_id,
            block_id="process_block_id",
            status=ExecutionStatus.FAILED,
            input_data={},
            output_data=output_data,
            add_time=datetime.now(timezone.utc),
            queue_time=None,
            start_time=None,
            end_time=None,
        )

    # Realistic UUIDs are used for the execution and node ids.
    executions = [
        failed_execution(
            "111e2222-e89b-12d3-a456-426614174010",
            "333e4444-e89b-12d3-a456-426614174011",
            {"error": ["Simple string error message"]},
        ),
        failed_execution(
            "555e6666-e89b-12d3-a456-426614174012",
            "777e8888-e89b-12d3-a456-426614174013",
            {},  # no error in output
        ),
    ]

    with patch(
        "backend.executor.activity_status_generator.get_block"
    ) as mock_get_block:
        mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
        summary = _build_execution_summary(
            executions,
            mock_execution_stats,
            "Error Test Graph",
            "Testing error formats",
            [],
            ExecutionStatus.FAILED,
        )

    # Errors are reported per node in recent_errors; both nodes must have some.
    nodes_with_errors = [
        node for node in summary["nodes"] if node.get("recent_errors")
    ]
    assert len(nodes_with_errors) == 2

    # Node with an explicit error message (looked up by truncated id).
    message_node = next(
        node for node in summary["nodes"] if node["node_id"] == "333e4444"
    )
    message_errors = message_node["recent_errors"]
    assert len(message_errors) == 1
    assert message_errors[0]["error"] == "Simple string error message"

    # Node without any error output (looked up by truncated id).
    silent_node = next(
        node for node in summary["nodes"] if node["node_id"] == "777e8888"
    )
    silent_errors = silent_node["recent_errors"]
    assert len(silent_errors) == 1
    assert silent_errors[0]["error"] == "Unknown error"
class TestLLMCall:
    """Exercise ``_call_llm_direct`` with a patched ``llm_call``."""

    @pytest.mark.asyncio
    async def test_call_llm_direct_success(self):
        """The response text of the LLM call is returned as the result."""
        from pydantic import SecretStr

        from backend.data.model import APIKeyCredentials

        expected_text = (
            "Agent successfully processed user input and generated response."
        )
        llm_response = LLMResponse(
            raw_response={},
            prompt=[],
            response=expected_text,
            tool_calls=None,
            prompt_tokens=50,
            completion_tokens=20,
        )

        with patch(
            "backend.executor.activity_status_generator.llm_call",
            return_value=llm_response,
        ) as mock_llm_call:
            api_credentials = APIKeyCredentials(
                id="test",
                provider="openai",
                api_key=SecretStr("test_key"),
                title="Test",
            )
            messages = [{"role": "user", "content": "Test prompt"}]
            result = await _call_llm_direct(api_credentials, messages)

        assert result == expected_text
        mock_llm_call.assert_called_once()

    @pytest.mark.asyncio
    async def test_call_llm_direct_no_response(self):
        """A ``None`` result from ``llm_call`` yields the fallback message."""
        from pydantic import SecretStr

        from backend.data.model import APIKeyCredentials

        with patch(
            "backend.executor.activity_status_generator.llm_call",
            return_value=None,
        ):
            api_credentials = APIKeyCredentials(
                id="test",
                provider="openai",
                api_key=SecretStr("test_key"),
                title="Test",
            )
            messages = [{"role": "user", "content": "Test prompt"}]
            result = await _call_llm_direct(api_credentials, messages)

        assert result == "Unable to generate activity summary"
class TestGenerateActivityStatusForExecution:
    """Exercise ``generate_activity_status_for_execution`` with mocked collaborators."""

    @pytest.mark.asyncio
    async def test_generate_status_success(
        self, mock_node_executions, mock_execution_stats, mock_blocks
    ):
        """With feature, API key and data available, the LLM text is returned."""
        db_client = AsyncMock()
        db_client.get_node_executions.return_value = mock_node_executions

        # `name` is a reserved Mock constructor argument, so it is assigned as
        # an attribute after construction rather than passed as a keyword.
        graph_metadata = MagicMock(description="A test agent")
        graph_metadata.name = "Test Agent"
        db_client.get_graph_metadata.return_value = graph_metadata
        db_client.get_graph.return_value = MagicMock(links=[])

        expected_status = "I analyzed your data and provided the requested insights."

        with patch(
            "backend.executor.activity_status_generator.get_block"
        ) as mock_get_block, patch(
            "backend.executor.activity_status_generator.Settings"
        ) as mock_settings, patch(
            "backend.executor.activity_status_generator._call_llm_direct"
        ) as mock_llm, patch(
            "backend.executor.activity_status_generator.is_feature_enabled",
            return_value=True,
        ):
            mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
            mock_settings.return_value.secrets.openai_api_key = "test_key"
            mock_llm.return_value = expected_status
            result = await generate_activity_status_for_execution(
                graph_exec_id="test_exec",
                graph_id="test_graph",
                graph_version=1,
                execution_stats=mock_execution_stats,
                db_client=db_client,
                user_id="test_user",
            )

        assert result == expected_status
        # Each data source and the LLM must have been consulted exactly once.
        db_client.get_node_executions.assert_called_once()
        db_client.get_graph_metadata.assert_called_once()
        db_client.get_graph.assert_called_once()
        mock_llm.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_status_feature_disabled(self, mock_execution_stats):
        """With the feature flag off, ``None`` is returned and no executions are fetched."""
        db_client = AsyncMock()

        with patch(
            "backend.executor.activity_status_generator.is_feature_enabled",
            return_value=False,
        ):
            result = await generate_activity_status_for_execution(
                graph_exec_id="test_exec",
                graph_id="test_graph",
                graph_version=1,
                execution_stats=mock_execution_stats,
                db_client=db_client,
                user_id="test_user",
            )

        assert result is None
        db_client.get_node_executions.assert_not_called()

    @pytest.mark.asyncio
    async def test_generate_status_no_api_key(self, mock_execution_stats):
        """With an empty OpenAI API key, ``None`` is returned and no executions are fetched."""
        db_client = AsyncMock()

        with patch(
            "backend.executor.activity_status_generator.Settings"
        ) as mock_settings, patch(
            "backend.executor.activity_status_generator.is_feature_enabled",
            return_value=True,
        ):
            mock_settings.return_value.secrets.openai_api_key = ""
            result = await generate_activity_status_for_execution(
                graph_exec_id="test_exec",
                graph_id="test_graph",
                graph_version=1,
                execution_stats=mock_execution_stats,
                db_client=db_client,
                user_id="test_user",
            )

        assert result is None
        db_client.get_node_executions.assert_not_called()

    @pytest.mark.asyncio
    async def test_generate_status_exception_handling(self, mock_execution_stats):
        """An exception raised while fetching executions results in ``None``."""
        db_client = AsyncMock()
        db_client.get_node_executions.side_effect = Exception("Database error")

        with patch(
            "backend.executor.activity_status_generator.Settings"
        ) as mock_settings, patch(
            "backend.executor.activity_status_generator.is_feature_enabled",
            return_value=True,
        ):
            mock_settings.return_value.secrets.openai_api_key = "test_key"
            result = await generate_activity_status_for_execution(
                graph_exec_id="test_exec",
                graph_id="test_graph",
                graph_version=1,
                execution_stats=mock_execution_stats,
                db_client=db_client,
                user_id="test_user",
            )

        assert result is None

    @pytest.mark.asyncio
    async def test_generate_status_with_graph_name_fallback(
        self, mock_node_executions, mock_execution_stats, mock_blocks
    ):
        """Without graph metadata or graph, the prompt uses a fallback graph name."""
        db_client = AsyncMock()
        db_client.get_node_executions.return_value = mock_node_executions
        db_client.get_graph_metadata.return_value = None  # no metadata available
        db_client.get_graph.return_value = None  # no graph available

        with patch(
            "backend.executor.activity_status_generator.get_block"
        ) as mock_get_block, patch(
            "backend.executor.activity_status_generator.Settings"
        ) as mock_settings, patch(
            "backend.executor.activity_status_generator._call_llm_direct"
        ) as mock_llm, patch(
            "backend.executor.activity_status_generator.is_feature_enabled",
            return_value=True,
        ):
            mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
            mock_settings.return_value.secrets.openai_api_key = "test_key"
            mock_llm.return_value = "Agent completed execution."
            result = await generate_activity_status_for_execution(
                graph_exec_id="test_exec",
                graph_id="test_graph",
                graph_version=1,
                execution_stats=mock_execution_stats,
                db_client=db_client,
                user_id="test_user",
            )

        assert result == "Agent completed execution."
        # The prompt is the second positional argument of the LLM helper; its
        # second message must mention the fallback graph name.
        prompt_messages = mock_llm.call_args[0][1]
        assert "Graph test_graph" in prompt_messages[1]["content"]
class TestIntegration:
    """End-to-end style checks with only ``llm_call`` and data access mocked."""

    @pytest.mark.asyncio
    async def test_full_integration_flow(
        self, mock_node_executions, mock_execution_stats, mock_blocks
    ):
        """Run status generation down to ``llm_call`` and inspect the prompt sent."""
        db_client = AsyncMock()
        db_client.get_node_executions.return_value = mock_node_executions

        # `name` is a reserved Mock constructor argument, so it is assigned as
        # an attribute after construction rather than passed as a keyword.
        graph_metadata = MagicMock(description="Integration test agent")
        graph_metadata.name = "Test Integration Agent"
        db_client.get_graph_metadata.return_value = graph_metadata
        db_client.get_graph.return_value = MagicMock(links=[])

        expected_activity = "I processed user input but failed during final output generation due to system error."

        with patch(
            "backend.executor.activity_status_generator.get_block"
        ) as mock_get_block, patch(
            "backend.executor.activity_status_generator.Settings"
        ) as mock_settings, patch(
            "backend.executor.activity_status_generator.llm_call"
        ) as mock_llm_call, patch(
            "backend.executor.activity_status_generator.is_feature_enabled",
            return_value=True,
        ):
            mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
            mock_settings.return_value.secrets.openai_api_key = "test_key"
            mock_llm_call.return_value = LLMResponse(
                raw_response={},
                prompt=[],
                response=expected_activity,
                tool_calls=None,
                prompt_tokens=100,
                completion_tokens=30,
            )
            result = await generate_activity_status_for_execution(
                graph_exec_id="test_exec",
                graph_id="test_graph",
                graph_version=1,
                execution_stats=mock_execution_stats,
                db_client=db_client,
                user_id="test_user",
            )

        assert result == expected_activity

        # Inspect the prompt that was handed to the LLM as a keyword argument.
        sent_prompt = mock_llm_call.call_args.kwargs["prompt"]
        system_message, user_message = sent_prompt[0], sent_prompt[1]

        # System prompt
        assert system_message["role"] == "system"
        assert "user's perspective" in system_message["content"]

        # User prompt carries the graph name and the instructions
        user_content = user_message["content"]
        assert "Test Integration Agent" in user_content
        assert "user-friendly terms" in user_content.lower()

        # Execution data must be embedded in the prompt
        assert "{" in user_content  # should contain JSON data
        assert "overall_status" in user_content

    @pytest.mark.asyncio
    async def test_manager_integration_with_disabled_feature(
        self, mock_execution_stats
    ):
        """With the feature disabled, ``None`` is returned and the database is untouched."""
        db_client = AsyncMock()

        with patch(
            "backend.executor.activity_status_generator.is_feature_enabled",
            return_value=False,
        ):
            result = await generate_activity_status_for_execution(
                graph_exec_id="test_exec",
                graph_id="test_graph",
                graph_version=1,
                execution_stats=mock_execution_stats,
                db_client=db_client,
                user_id="test_user",
            )

        # Disabled feature yields no status ...
        assert result is None
        # ... and none of the data accessors may have been called.
        for accessor in (
            db_client.get_node_executions,
            db_client.get_graph_metadata,
            db_client.get_graph,
        ):
            accessor.assert_not_called()

View File

@@ -7,6 +7,7 @@ from backend.data.execution import (
create_graph_execution,
get_block_error_stats,
get_execution_kv_data,
get_graph_execution,
get_graph_execution_meta,
get_graph_executions,
get_latest_node_execution,
@@ -15,12 +16,12 @@ from backend.data.execution import (
set_execution_kv_data,
update_graph_execution_start_time,
update_graph_execution_stats,
update_node_execution_stats,
update_node_execution_status,
update_node_execution_status_batch,
upsert_execution_input,
upsert_execution_output,
)
from backend.data.generate_data import get_user_execution_summary_data
from backend.data.graph import (
get_connected_output_nodes,
get_graph,
@@ -39,16 +40,12 @@ from backend.data.user import (
get_user_email_by_id,
get_user_email_verification,
get_user_integrations,
get_user_metadata,
get_user_notification_preference,
update_user_integrations,
update_user_metadata,
)
from backend.util.service import (
AppService,
AppServiceClient,
UnhealthyServiceError,
endpoint_to_sync,
expose,
)
from backend.util.service import AppService, AppServiceClient, endpoint_to_sync, expose
from backend.util.settings import Config
config = Config()
@@ -80,11 +77,6 @@ class DatabaseManager(AppService):
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
self.run_and_wait(db.disconnect())
async def health_check(self) -> str:
if not db.is_connected():
raise UnhealthyServiceError("Database is not connected")
return await super().health_check()
@classmethod
def get_port(cls) -> int:
return config.database_api_port
@@ -98,6 +90,7 @@ class DatabaseManager(AppService):
return cast(Callable[Concatenate[object, P], R], expose(f))
# Executions
get_graph_execution = _(get_graph_execution)
get_graph_executions = _(get_graph_executions)
get_graph_execution_meta = _(get_graph_execution_meta)
create_graph_execution = _(create_graph_execution)
@@ -108,6 +101,7 @@ class DatabaseManager(AppService):
update_node_execution_status_batch = _(update_node_execution_status_batch)
update_graph_execution_start_time = _(update_graph_execution_start_time)
update_graph_execution_stats = _(update_graph_execution_stats)
update_node_execution_stats = _(update_node_execution_stats)
upsert_execution_input = _(upsert_execution_input)
upsert_execution_output = _(upsert_execution_output)
get_execution_kv_data = _(get_execution_kv_data)
@@ -125,6 +119,8 @@ class DatabaseManager(AppService):
get_credits = _(_get_credits, name="get_credits")
# User + User Metadata + User Integrations
get_user_metadata = _(get_user_metadata)
update_user_metadata = _(update_user_metadata)
get_user_integrations = _(get_user_integrations)
update_user_integrations = _(update_user_integrations)
@@ -145,9 +141,6 @@ class DatabaseManager(AppService):
get_user_notification_oldest_message_in_batch
)
# Summary data - async
get_user_execution_summary_data = _(get_user_execution_summary_data)
class DatabaseManagerClient(AppServiceClient):
d = DatabaseManager
@@ -158,23 +151,55 @@ class DatabaseManagerClient(AppServiceClient):
return DatabaseManager
# Executions
get_graph_execution = _(d.get_graph_execution)
get_graph_executions = _(d.get_graph_executions)
get_graph_execution_meta = _(d.get_graph_execution_meta)
create_graph_execution = _(d.create_graph_execution)
get_node_execution = _(d.get_node_execution)
get_node_executions = _(d.get_node_executions)
get_latest_node_execution = _(d.get_latest_node_execution)
update_node_execution_status = _(d.update_node_execution_status)
update_node_execution_status_batch = _(d.update_node_execution_status_batch)
update_graph_execution_start_time = _(d.update_graph_execution_start_time)
update_graph_execution_stats = _(d.update_graph_execution_stats)
update_node_execution_stats = _(d.update_node_execution_stats)
upsert_execution_input = _(d.upsert_execution_input)
upsert_execution_output = _(d.upsert_execution_output)
get_execution_kv_data = _(d.get_execution_kv_data)
set_execution_kv_data = _(d.set_execution_kv_data)
# Graphs
get_node = _(d.get_node)
get_graph = _(d.get_graph)
get_connected_output_nodes = _(d.get_connected_output_nodes)
get_graph_metadata = _(d.get_graph_metadata)
# Credits
spend_credits = _(d.spend_credits)
get_credits = _(d.get_credits)
# Summary data - async
get_user_execution_summary_data = _(d.get_user_execution_summary_data)
# User + User Metadata + User Integrations
get_user_metadata = _(d.get_user_metadata)
update_user_metadata = _(d.update_user_metadata)
get_user_integrations = _(d.get_user_integrations)
update_user_integrations = _(d.update_user_integrations)
# User Comms - async
get_active_user_ids_in_timerange = _(d.get_active_user_ids_in_timerange)
get_user_email_by_id = _(d.get_user_email_by_id)
get_user_email_verification = _(d.get_user_email_verification)
get_user_notification_preference = _(d.get_user_notification_preference)
# Notifications - async
create_or_add_to_user_notification_batch = _(
d.create_or_add_to_user_notification_batch
)
empty_user_notification_batch = _(d.empty_user_notification_batch)
get_all_batches_by_type = _(d.get_all_batches_by_type)
get_user_notification_batch = _(d.get_user_notification_batch)
get_user_notification_oldest_message_in_batch = _(
d.get_user_notification_oldest_message_in_batch
)
# Block error monitoring
get_block_error_stats = _(d.get_block_error_stats)
@@ -200,28 +225,10 @@ class DatabaseManagerAsyncClient(AppServiceClient):
upsert_execution_input = d.upsert_execution_input
upsert_execution_output = d.upsert_execution_output
update_graph_execution_stats = d.update_graph_execution_stats
update_node_execution_stats = d.update_node_execution_stats
update_node_execution_status = d.update_node_execution_status
update_node_execution_status_batch = d.update_node_execution_status_batch
update_user_integrations = d.update_user_integrations
get_execution_kv_data = d.get_execution_kv_data
set_execution_kv_data = d.set_execution_kv_data
# User Comms
get_active_user_ids_in_timerange = d.get_active_user_ids_in_timerange
get_user_email_by_id = d.get_user_email_by_id
get_user_email_verification = d.get_user_email_verification
get_user_notification_preference = d.get_user_notification_preference
# Notifications
create_or_add_to_user_notification_batch = (
d.create_or_add_to_user_notification_batch
)
empty_user_notification_batch = d.empty_user_notification_batch
get_all_batches_by_type = d.get_all_batches_by_type
get_user_notification_batch = d.get_user_notification_batch
get_user_notification_oldest_message_in_batch = (
d.get_user_notification_oldest_message_in_batch
)
# Summary data
get_user_execution_summary_data = d.get_user_execution_summary_data
get_block_error_stats = d.get_block_error_stats

File diff suppressed because it is too large Load Diff

View File

@@ -3,6 +3,7 @@ import logging
import autogpt_libs.auth.models
import fastapi.responses
import pytest
from prisma.models import User
import backend.server.v2.library.model
import backend.server.v2.store.model
@@ -11,7 +12,6 @@ from backend.blocks.data_manipulation import FindInDictionaryBlock
from backend.blocks.io import AgentInputBlock
from backend.blocks.maths import CalculatorBlock, Operation
from backend.data import execution, graph
from backend.data.model import User
from backend.server.model import CreateGraph
from backend.server.rest_api import AgentServer
from backend.usecases.sample import create_test_graph, create_test_user

View File

@@ -1,22 +1,17 @@
import asyncio
import logging
import os
import threading
from enum import Enum
from typing import Optional
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from apscheduler.events import (
EVENT_JOB_ERROR,
EVENT_JOB_EXECUTED,
EVENT_JOB_MAX_INSTANCES,
EVENT_JOB_MISSED,
)
from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
from apscheduler.job import Job as JobObj
from apscheduler.jobstores.memory import MemoryJobStore
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.schedulers.blocking import BlockingScheduler
from apscheduler.triggers.cron import CronTrigger
from autogpt_libs.utils.cache import thread_cached
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import MetaData, create_engine
@@ -35,14 +30,7 @@ from backend.monitoring import (
from backend.util.cloud_storage import cleanup_expired_files_async
from backend.util.exceptions import NotAuthorizedError, NotFoundError
from backend.util.logging import PrefixFilter
from backend.util.retry import func_retry
from backend.util.service import (
AppService,
AppServiceClient,
UnhealthyServiceError,
endpoint_to_async,
expose,
)
from backend.util.service import AppService, AppServiceClient, endpoint_to_async, expose
from backend.util.settings import Config
@@ -72,69 +60,26 @@ apscheduler_logger.addFilter(PrefixFilter("[Scheduler] [APScheduler]"))
config = Config()
# Timeout constants
SCHEDULER_OPERATION_TIMEOUT_SECONDS = 300 # 5 minutes for scheduler operations
def job_listener(event):
"""Logs job execution outcomes for better monitoring."""
if event.exception:
logger.error(
f"Job {event.job_id} failed: {type(event.exception).__name__}: {event.exception}"
)
logger.error(f"Job {event.job_id} failed.")
else:
logger.info(f"Job {event.job_id} completed successfully.")
def job_missed_listener(event):
"""Logs when jobs are missed due to scheduling issues."""
logger.warning(
f"Job {event.job_id} was missed at scheduled time {event.scheduled_run_time}. "
f"This can happen if the scheduler is overloaded or if previous executions are still running."
)
def job_max_instances_listener(event):
"""Logs when jobs hit max instances limit."""
logger.warning(
f"Job {event.job_id} execution was SKIPPED - max instances limit reached. "
f"Previous execution(s) are still running. "
f"Consider increasing max_instances or check why previous executions are taking too long."
)
_event_loop: asyncio.AbstractEventLoop | None = None
_event_loop_thread: threading.Thread | None = None
@func_retry
@thread_cached
def get_event_loop():
"""Get the shared event loop."""
if _event_loop is None:
raise RuntimeError("Event loop not initialized. Scheduler not started.")
return _event_loop
def run_async(coro, timeout: float = SCHEDULER_OPERATION_TIMEOUT_SECONDS):
"""Run a coroutine in the shared event loop and wait for completion."""
loop = get_event_loop()
future = asyncio.run_coroutine_threadsafe(coro, loop)
try:
return future.result(timeout=timeout)
except Exception as e:
logger.error(f"Async operation failed: {type(e).__name__}: {e}")
raise
return asyncio.new_event_loop()
def execute_graph(**kwargs):
"""Execute graph in the shared event loop and wait for completion."""
# Wait for completion to ensure job doesn't exit prematurely
run_async(_execute_graph(**kwargs))
get_event_loop().run_until_complete(_execute_graph(**kwargs))
async def _execute_graph(**kwargs):
args = GraphExecutionJobArgs(**kwargs)
start_time = asyncio.get_event_loop().time()
try:
logger.info(f"Executing recurring job for graph #{args.graph_id}")
graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution(
@@ -144,28 +89,16 @@ async def _execute_graph(**kwargs):
inputs=args.input_data,
graph_credentials_inputs=args.input_credentials,
)
elapsed = asyncio.get_event_loop().time() - start_time
logger.info(
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
f"(took {elapsed:.2f}s to create and publish)"
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id}"
)
if elapsed > 10:
logger.warning(
f"Graph execution {graph_exec.id} took {elapsed:.2f}s to create/publish - "
f"this is unusually slow and may indicate resource contention"
)
except Exception as e:
elapsed = asyncio.get_event_loop().time() - start_time
logger.error(
f"Error executing graph {args.graph_id} after {elapsed:.2f}s: "
f"{type(e).__name__}: {e}"
)
logger.error(f"Error executing graph {args.graph_id}: {e}")
def cleanup_expired_files():
"""Clean up expired files from cloud storage."""
# Wait for completion
run_async(cleanup_expired_files_async())
get_event_loop().run_until_complete(cleanup_expired_files_async())
# Monitoring functions are now imported from monitoring module
@@ -221,7 +154,7 @@ class NotificationJobInfo(NotificationJobArgs):
class Scheduler(AppService):
scheduler: BackgroundScheduler
scheduler: BlockingScheduler
def __init__(self, register_system_tasks: bool = True):
self.register_system_tasks = register_system_tasks
@@ -234,50 +167,10 @@ class Scheduler(AppService):
def db_pool_size(cls) -> int:
return config.scheduler_db_pool_size
async def health_check(self) -> str:
# Thread-safe health check with proper initialization handling
if not hasattr(self, "scheduler"):
raise UnhealthyServiceError("Scheduler is still initializing")
# Check if we're in the middle of cleanup
if self.cleaned_up:
return await super().health_check()
# Normal operation - check if scheduler is running
if not self.scheduler.running:
raise UnhealthyServiceError("Scheduler is not running")
return await super().health_check()
def run_service(self):
load_dotenv()
# Initialize the event loop for async jobs
global _event_loop
_event_loop = asyncio.new_event_loop()
# Use daemon thread since it should die with the main service
global _event_loop_thread
_event_loop_thread = threading.Thread(
target=_event_loop.run_forever, daemon=True, name="SchedulerEventLoop"
)
_event_loop_thread.start()
db_schema, db_url = _extract_schema_from_url(os.getenv("DIRECT_URL"))
# Configure executors to limit concurrency without skipping jobs
from apscheduler.executors.pool import ThreadPoolExecutor
self.scheduler = BackgroundScheduler(
executors={
"default": ThreadPoolExecutor(
max_workers=self.db_pool_size()
), # Match DB pool size to prevent resource contention
},
job_defaults={
"coalesce": True, # Skip redundant missed jobs - just run the latest
"max_instances": 1000, # Effectively unlimited - never drop executions
"misfire_grace_time": None, # No time limit for missed jobs
},
self.scheduler = BlockingScheduler(
jobstores={
Jobstores.EXECUTION.value: SQLAlchemyJobStore(
engine=create_engine(
@@ -307,10 +200,9 @@ class Scheduler(AppService):
if self.register_system_tasks:
# Notification PROCESS WEEKLY SUMMARY
# Runs every Monday at 9 AM UTC
self.scheduler.add_job(
process_weekly_summary,
CronTrigger.from_crontab("0 9 * * 1"),
CronTrigger.from_crontab("0 * * * *"),
id="process_weekly_summary",
kwargs={},
replace_existing=True,
@@ -358,30 +250,13 @@ class Scheduler(AppService):
)
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
self.scheduler.start()
# Keep the service running since BackgroundScheduler doesn't block
super().run_service()
def cleanup(self):
super().cleanup()
logger.info("⏳ Shutting down scheduler...")
if self.scheduler:
logger.info("⏳ Shutting down scheduler...")
self.scheduler.shutdown(wait=True)
global _event_loop
if _event_loop:
logger.info("⏳ Closing event loop...")
_event_loop.call_soon_threadsafe(_event_loop.stop)
global _event_loop_thread
if _event_loop_thread:
logger.info("⏳ Waiting for event loop thread to finish...")
_event_loop_thread.join(timeout=SCHEDULER_OPERATION_TIMEOUT_SECONDS)
logger.info("Scheduler cleanup complete.")
self.scheduler.shutdown(wait=False)
@expose
def add_graph_execution_schedule(
@@ -394,18 +269,6 @@ class Scheduler(AppService):
input_credentials: dict[str, CredentialsMetaInput],
name: Optional[str] = None,
) -> GraphExecutionJobInfo:
# Validate the graph before scheduling to prevent runtime failures
# We don't need the return value, just want the validation to run
run_async(
execution_utils.validate_and_construct_node_execution_input(
graph_id=graph_id,
user_id=user_id,
graph_inputs=input_data,
graph_version=graph_version,
graph_credentials_inputs=input_credentials,
)
)
job_args = GraphExecutionJobArgs(
user_id=user_id,
graph_id=graph_id,

View File

@@ -1,9 +1,10 @@
import pytest
from backend.data import db
from backend.executor.scheduler import SchedulerClient
from backend.server.model import CreateGraph
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.clients import get_scheduler_client
from backend.util.service import get_service_client
from backend.util.test import SpinTestServer
@@ -16,7 +17,7 @@ async def test_agent_schedule(server: SpinTestServer):
user_id=test_user.id,
)
scheduler = get_scheduler_client()
scheduler = get_service_client(SchedulerClient)
schedules = await scheduler.get_execution_schedules(test_graph.id, test_user.id)
assert len(schedules) == 0

View File

@@ -4,36 +4,52 @@ import threading
import time
from collections import defaultdict
from concurrent.futures import Future
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
from pydantic import BaseModel, JsonValue, ValidationError
from autogpt_libs.utils.cache import thread_cached
from pydantic import BaseModel, JsonValue
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
from backend.data.block import (
Block,
BlockData,
BlockInput,
BlockSchema,
BlockType,
get_block,
)
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCostType
from backend.data.db import prisma
from backend.data.execution import (
AsyncRedisExecutionEventBus,
ExecutionStatus,
GraphExecutionStats,
GraphExecutionWithNodes,
RedisExecutionEventBus,
)
from backend.data.graph import GraphModel, Node
from backend.data.model import CredentialsMetaInput
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.util.clients import (
get_async_execution_event_bus,
get_async_execution_queue,
get_database_manager_async_client,
get_integration_credentials_store,
from backend.data.rabbitmq import (
AsyncRabbitMQ,
Exchange,
ExchangeType,
Queue,
RabbitMQConfig,
SyncRabbitMQ,
)
from backend.util.exceptions import GraphValidationError, NotFoundError
from backend.util.exceptions import NotFoundError
from backend.util.logging import TruncatedLogger
from backend.util.mock import MockObject
from backend.util.service import get_service_client
from backend.util.settings import Config
from backend.util.type import convert
if TYPE_CHECKING:
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
from backend.integrations.credentials_store import IntegrationCredentialsStore
config = Config()
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil]")
@@ -70,6 +86,51 @@ class LogMetadata(TruncatedLogger):
)
@thread_cached
def get_execution_event_bus() -> RedisExecutionEventBus:
    """
    Build the synchronous Redis execution event bus.

    The factory is wrapped in `thread_cached`; presumably that hands back one
    cached instance per thread (confirm in `autogpt_libs.utils.cache`).

    Returns:
        RedisExecutionEventBus: a freshly constructed event bus.
    """
    event_bus = RedisExecutionEventBus()
    return event_bus
@thread_cached
def get_async_execution_event_bus() -> AsyncRedisExecutionEventBus:
    """
    Build the asyncio flavour of the Redis execution event bus.

    Note that this factory itself is a plain (non-async) function; only the
    returned bus is asynchronous. Wrapped in `thread_cached` (presumably one
    cached instance per thread; confirm in `autogpt_libs.utils.cache`).

    Returns:
        AsyncRedisExecutionEventBus: a freshly constructed event bus.
    """
    async_event_bus = AsyncRedisExecutionEventBus()
    return async_event_bus
@thread_cached
def get_execution_queue() -> SyncRabbitMQ:
    """
    Create a blocking RabbitMQ client for the execution queues and connect it.

    The client is configured from `create_execution_queue_config()` and is
    connected before being handed to the caller.

    Returns:
        SyncRabbitMQ: the client, after `connect()` has been called on it.
    """
    queue_config = create_execution_queue_config()
    queue_client = SyncRabbitMQ(queue_config)
    queue_client.connect()
    return queue_client
@thread_cached
async def get_async_execution_queue() -> AsyncRabbitMQ:
    """
    Create an asyncio RabbitMQ client for the execution queues and connect it.

    The client is configured from `create_execution_queue_config()` and its
    `connect()` coroutine is awaited before it is returned.

    NOTE(review): this is a coroutine function decorated with `thread_cached`;
    confirm that the decorator caches the awaited result rather than the
    coroutine object.

    Returns:
        AsyncRabbitMQ: the client, after `connect()` has been awaited.
    """
    queue_config = create_execution_queue_config()
    queue_client = AsyncRabbitMQ(queue_config)
    await queue_client.connect()
    return queue_client
@thread_cached
def get_integration_credentials_store() -> "IntegrationCredentialsStore":
    """
    Build an `IntegrationCredentialsStore`.

    Returns:
        IntegrationCredentialsStore: a freshly constructed store.
    """
    # Imported at call time: at module level this name is only imported under
    # TYPE_CHECKING (presumably to avoid a circular import; verify).
    from backend.integrations.credentials_store import IntegrationCredentialsStore

    credentials_store = IntegrationCredentialsStore()
    return credentials_store
@thread_cached
def get_db_client() -> "DatabaseManagerClient":
    """
    Obtain a service client for `DatabaseManagerClient`.

    Returns:
        DatabaseManagerClient: whatever `get_service_client` yields for that
        client type.
    """
    # Imported at call time: at module level this name is only imported under
    # TYPE_CHECKING (presumably to avoid a circular import; verify).
    from backend.executor import DatabaseManagerClient

    db_client = get_service_client(DatabaseManagerClient)
    return db_client
@thread_cached
def get_db_async_client() -> "DatabaseManagerAsyncClient":
    """
    Obtain a service client for `DatabaseManagerAsyncClient`.

    Returns:
        DatabaseManagerAsyncClient: whatever `get_service_client` yields for
        that client type.
    """
    # Imported at call time: at module level this name is only imported under
    # TYPE_CHECKING (presumably to avoid a circular import; verify).
    from backend.executor import DatabaseManagerAsyncClient

    async_db_client = get_service_client(DatabaseManagerAsyncClient)
    return async_db_client
# ============ Execution Cost Helpers ============ #
@@ -396,7 +457,7 @@ def validate_exec(
# Last validation: Validate the input values against the schema.
if error := schema.get_mismatch_error(data):
error_message = f"{error_prefix} {error}"
logger.warning(error_message)
logger.error(error_message)
return None, error_message
return data, node_block.name
@@ -406,65 +467,47 @@ async def _validate_node_input_credentials(
graph: GraphModel,
user_id: str,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> dict[str, dict[str, str]]:
"""
Checks all credentials for all nodes of the graph and returns structured errors.
Returns:
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node
"""
credential_errors: dict[str, dict[str, str]] = defaultdict(dict)
):
"""Checks all credentials for all nodes of the graph"""
for node in graph.nodes:
block = node.block
# Find any fields of type CredentialsMetaInput
credentials_fields = block.input_schema.get_credentials_fields()
credentials_fields = cast(
type[BlockSchema], block.input_schema
).get_credentials_fields()
if not credentials_fields:
continue
for field_name, credentials_meta_type in credentials_fields.items():
try:
if (
nodes_input_masks
and (node_input_mask := nodes_input_masks.get(node.id))
and field_name in node_input_mask
):
credentials_meta = credentials_meta_type.model_validate(
node_input_mask[field_name]
)
elif field_name in node.input_default:
credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
)
else:
# Missing credentials
credential_errors[node.id][
field_name
] = "These credentials are required"
continue
except ValidationError as e:
credential_errors[node.id][field_name] = f"Invalid credentials: {e}"
continue
try:
# Fetch the corresponding Credentials and perform sanity checks
credentials = await get_integration_credentials_store().get_creds_by_id(
user_id, credentials_meta.id
if (
nodes_input_masks
and (node_input_mask := nodes_input_masks.get(node.id))
and field_name in node_input_mask
):
credentials_meta = credentials_meta_type.model_validate(
node_input_mask[field_name]
)
elif field_name in node.input_default:
credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
)
else:
raise ValueError(
f"Credentials absent for {block.name} node #{node.id} "
f"input '{field_name}'"
)
except Exception as e:
# Handle any errors fetching credentials
credential_errors[node.id][
field_name
] = f"Credentials not available: {e}"
continue
# Fetch the corresponding Credentials and perform sanity checks
credentials = await get_integration_credentials_store().get_creds_by_id(
user_id, credentials_meta.id
)
if not credentials:
credential_errors[node.id][
field_name
] = f"Unknown credentials #{credentials_meta.id}"
continue
raise ValueError(
f"Unknown credentials #{credentials_meta.id} "
f"for node #{node.id} input '{field_name}'"
)
if (
credentials.provider != credentials_meta.provider
or credentials.type != credentials_meta.type
@@ -475,12 +518,10 @@ async def _validate_node_input_credentials(
f"{credentials_meta.type}<>{credentials.type};"
f"{credentials_meta.provider}<>{credentials.provider}"
)
credential_errors[node.id][
field_name
] = "Invalid credentials: type/provider mismatch"
continue
return credential_errors
raise ValueError(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch"
)
def make_node_credentials_input_map(
@@ -518,37 +559,7 @@ def make_node_credentials_input_map(
return result
async def validate_graph_with_credentials(
    graph: GraphModel,
    user_id: str,
    nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> dict[str, dict[str, str]]:
    """
    Gather input-validation and credential-validation errors for a graph.

    Input errors come from `GraphModel.validate_graph_get_errors` (run with
    `for_run=True`); credential errors come from
    `_validate_node_input_credentials`. The two are combined per node, and for
    a field reported by both, the credential error replaces the input error.

    Args:
        graph: The graph to validate.
        user_id: ID of the user whose credentials are checked.
        nodes_input_masks: Optional per-node input overrides, passed through
            to both validators.

    Returns:
        dict[node_id, dict[field_name, error_message]]: Validation errors per node
    """
    errors_by_node = GraphModel.validate_graph_get_errors(
        graph, for_run=True, nodes_input_masks=nodes_input_masks
    )
    credential_errors = await _validate_node_input_credentials(
        graph, user_id, nodes_input_masks
    )

    # Fold the credential errors into the per-node error mapping.
    for node_id, field_errors in credential_errors.items():
        errors_by_node.setdefault(node_id, {}).update(field_errors)

    return errors_by_node
async def _construct_starting_node_execution_input(
async def construct_node_execution_input(
graph: GraphModel,
user_id: str,
graph_inputs: BlockInput,
@@ -570,17 +581,8 @@ async def _construct_starting_node_execution_input(
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
the corresponding input data for that node.
"""
# Use new validation function that includes credentials
validation_errors = await validate_graph_with_credentials(
graph, user_id, nodes_input_masks
)
n_error_nodes = len(validation_errors)
n_errors = sum(len(errors) for errors in validation_errors.values())
if validation_errors:
raise GraphValidationError(
f"Graph validation failed: {n_errors} issues on {n_error_nodes} nodes",
node_errors=validation_errors,
)
graph.validate_graph(for_run=True, nodes_input_masks=nodes_input_masks)
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
nodes_input = []
for node in graph.starting_nodes:
@@ -615,67 +617,6 @@ async def _construct_starting_node_execution_input(
return nodes_input
async def validate_and_construct_node_execution_input(
    graph_id: str,
    user_id: str,
    graph_inputs: BlockInput,
    graph_version: Optional[int] = None,
    graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
    nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
) -> tuple[GraphModel, list[tuple[str, BlockInput]], dict[str, dict[str, JsonValue]]]:
    """
    Fetch a graph, merge its input masks, then validate it and build the
    execution input for its starting nodes.

    Args:
        graph_id: The ID of the graph to validate/construct.
        user_id: The ID of the user.
        graph_inputs: The input data for the graph execution.
        graph_version: The version of the graph to use.
        graph_credentials_inputs: Credentials inputs to use; when given they are
            turned into node input masks via `make_node_credentials_input_map`.
        nodes_input_masks: Node inputs to use; merged with the
            credentials-derived masks via `_merge_nodes_input_masks`.

    Returns:
        A 3-tuple of:
        - the fetched `GraphModel`,
        - the list of `(node_id, input_data)` tuples for the starting nodes,
        - the merged node input masks that were used.

    Raises:
        NotFoundError: If the graph is not found.
        GraphValidationError: If the graph has validation issues.
        ValueError: If there are other validation errors.
    """
    # Query the DB module directly when Prisma is connected in this process;
    # otherwise go through the database manager client.
    gdb = graph_db if prisma.is_connected() else get_database_manager_async_client()

    graph: GraphModel | None = await gdb.get_graph(
        graph_id=graph_id,
        user_id=user_id,
        version=graph_version,
        include_subgraphs=True,
    )
    if not graph:
        raise NotFoundError(f"Graph #{graph_id} not found.")

    credentials_masks = (
        make_node_credentials_input_map(graph, graph_credentials_inputs)
        if graph_credentials_inputs
        else {}
    )
    merged_masks = _merge_nodes_input_masks(
        credentials_masks,
        nodes_input_masks or {},
    )

    starting_nodes_input = await _construct_starting_node_execution_input(
        graph=graph,
        user_id=user_id,
        graph_inputs=graph_inputs,
        nodes_input_masks=merged_masks,
    )

    return graph, starting_nodes_input, merged_masks
def _merge_nodes_input_masks(
overrides_map_1: dict[str, dict[str, JsonValue]],
overrides_map_2: dict[str, dict[str, JsonValue]],
@@ -692,6 +633,11 @@ def _merge_nodes_input_masks(
# ============ Execution Queue Helpers ============ #
# Message payload for cancelling a graph execution: `stop_graph_execution`
# serializes an instance with `model_dump_json()` and publishes it via the
# execution queue client.
class CancelExecutionEvent(BaseModel):
    # ID of the graph execution to cancel.
    graph_exec_id: str
GRAPH_EXECUTION_EXCHANGE = Exchange(
name="graph_execution",
type=ExchangeType.DIRECT,
@@ -709,11 +655,6 @@ GRAPH_EXECUTION_CANCEL_EXCHANGE = Exchange(
)
GRAPH_EXECUTION_CANCEL_QUEUE_NAME = "graph_execution_cancel_queue"
# Graceful shutdown timeout constants
# Agent executions can run for up to 1 day, so we need a graceful shutdown period
# that allows long-running executions to complete naturally
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = 24 * 60 * 60 # 1 day to complete active executions
def create_execution_queue_config() -> RabbitMQConfig:
"""
@@ -728,14 +669,13 @@ def create_execution_queue_config() -> RabbitMQConfig:
durable=True,
auto_delete=False,
arguments={
# x-consumer-timeout (1 week)
# x-consumer-timeout (0 = disabled)
# Problem: Default 30-minute consumer timeout kills long-running graph executions
# Original error: "Consumer acknowledgement timed out after 1800000 ms (30 minutes)"
# Solution: Disable consumer timeout entirely - let graphs run indefinitely
# Safety: Heartbeat mechanism now handles dead consumer detection instead
# Use case: Graph executions that take hours to complete (AI model training, etc.)
"x-consumer-timeout": GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS
* 1000,
"x-consumer-timeout": 0,
},
)
cancel_queue = Queue(
@@ -752,10 +692,6 @@ def create_execution_queue_config() -> RabbitMQConfig:
)
# Message payload for cancelling a graph execution: `stop_graph_execution`
# serializes an instance with `model_dump_json()` and publishes it via the
# execution queue client.
class CancelExecutionEvent(BaseModel):
    # ID of the graph execution to cancel.
    graph_exec_id: str
async def stop_graph_execution(
user_id: str,
graph_exec_id: str,
@@ -769,7 +705,7 @@ async def stop_graph_execution(
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
"""
queue_client = await get_async_execution_queue()
db = execution_db if prisma.is_connected() else get_database_manager_async_client()
db = execution_db if prisma.is_connected() else get_db_async_client()
await queue_client.publish_message(
routing_key="",
message=CancelExecutionEvent(graph_exec_id=graph_exec_id).model_dump_json(),
@@ -801,28 +737,51 @@ async def stop_graph_execution(
ExecutionStatus.QUEUED,
ExecutionStatus.INCOMPLETE,
]:
# If the graph is still on the queue, we can prevent them from being executed
# by setting the status to TERMINATED.
graph_exec.status = ExecutionStatus.TERMINATED
await asyncio.gather(
# Update graph execution status
db.update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=ExecutionStatus.TERMINATED,
),
# Publish graph execution event
get_async_execution_event_bus().publish(graph_exec),
)
return
break
if graph_exec.status == ExecutionStatus.RUNNING:
await asyncio.sleep(0.1)
raise TimeoutError(
f"Graph execution #{graph_exec_id} will need to take longer than {wait_timeout} seconds to stop. "
f"You can check the status of the execution in the UI or try again later."
)
# Set the termination status if the graph is not stopped after the timeout.
if graph_exec := await db.get_graph_execution_meta(
execution_id=graph_exec_id, user_id=user_id
):
# If the graph is still on the queue, we can prevent them from being executed
# by setting the status to TERMINATED.
node_execs = await db.get_node_executions(
graph_exec_id=graph_exec_id,
statuses=[
ExecutionStatus.QUEUED,
ExecutionStatus.RUNNING,
],
include_exec_data=False,
)
graph_exec.status = ExecutionStatus.TERMINATED
for node_exec in node_execs:
node_exec.status = ExecutionStatus.TERMINATED
await asyncio.gather(
# Update node execution statuses
db.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in node_execs],
ExecutionStatus.TERMINATED,
),
# Publish node execution events
*[
get_async_execution_event_bus().publish(node_exec)
for node_exec in node_execs
],
)
await asyncio.gather(
# Update graph execution status
db.update_graph_execution_stats(
graph_exec_id=graph_exec_id,
status=ExecutionStatus.TERMINATED,
),
# Publish graph execution event
get_async_execution_event_bus().publish(graph_exec),
)
async def add_graph_execution(
@@ -852,62 +811,61 @@ async def add_graph_execution(
ValueError: If the graph is not found or if there are validation errors.
"""
if prisma.is_connected():
gdb = graph_db
edb = execution_db
else:
edb = get_database_manager_async_client()
gdb = get_db_async_client()
edb = get_db_async_client()
graph, starting_nodes_input, nodes_input_masks = (
await validate_and_construct_node_execution_input(
graph_id=graph_id,
user_id=user_id,
graph_inputs=inputs or {},
graph_version=graph_version,
graph_credentials_inputs=graph_credentials_inputs,
nodes_input_masks=nodes_input_masks,
)
graph: GraphModel | None = await gdb.get_graph(
graph_id=graph_id,
user_id=user_id,
version=graph_version,
include_subgraphs=True,
)
if not graph:
raise NotFoundError(f"Graph #{graph_id} not found.")
nodes_input_masks = _merge_nodes_input_masks(
(
make_node_credentials_input_map(graph, graph_credentials_inputs)
if graph_credentials_inputs
else {}
),
nodes_input_masks or {},
)
starting_nodes_input = await construct_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=inputs or {},
nodes_input_masks=nodes_input_masks,
)
graph_exec = await edb.create_graph_execution(
user_id=user_id,
graph_id=graph_id,
graph_version=graph.version,
starting_nodes_input=starting_nodes_input,
preset_id=preset_id,
)
graph_exec = None
try:
graph_exec = await edb.create_graph_execution(
user_id=user_id,
graph_id=graph_id,
graph_version=graph.version,
starting_nodes_input=starting_nodes_input,
preset_id=preset_id,
)
queue = await get_async_execution_queue()
graph_exec_entry = graph_exec.to_graph_execution_entry()
if nodes_input_masks:
graph_exec_entry.nodes_input_masks = nodes_input_masks
logger.info(
f"Created graph execution #{graph_exec.id} for graph "
f"#{graph_id} with {len(starting_nodes_input)} starting nodes. "
f"Now publishing to execution queue."
)
await queue.publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
message=graph_exec_entry.model_dump_json(),
exchange=GRAPH_EXECUTION_EXCHANGE,
)
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
bus = get_async_execution_event_bus()
await bus.publish(graph_exec)
return graph_exec
except BaseException as e:
err = str(e) or type(e).__name__
if not graph_exec:
logger.error(f"Unable to execute graph #{graph_id} failed: {err}")
raise
logger.error(
f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {err}"
)
except Exception as e:
logger.error(f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {e}")
await edb.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in graph_exec.node_executions],
ExecutionStatus.FAILED,
@@ -915,7 +873,7 @@ async def add_graph_execution(
await edb.update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=ExecutionStatus.FAILED,
stats=GraphExecutionStats(error=err),
stats=GraphExecutionStats(error=str(e)),
)
raise
@@ -930,9 +888,13 @@ class ExecutionOutputEntry(BaseModel):
class NodeExecutionProgress:
def __init__(self):
def __init__(
self,
on_done_task: Callable[[str, object], None],
):
self.output: dict[str, list[ExecutionOutputEntry]] = defaultdict(list)
self.tasks: dict[str, Future] = {}
self.on_done_task = on_done_task
self._lock = threading.Lock()
def add_task(self, node_exec_id: str, task: Future):
@@ -972,9 +934,7 @@ class NodeExecutionProgress:
except TimeoutError:
pass
except Exception as e:
logger.error(
f"Task for exec ID {exec_id} failed with error: {e.__class__.__name__} {str(e)}"
)
logger.error(f"Task for exec ID {exec_id} failed with error: {str(e)}")
pass
return self.is_done(0)
@@ -992,7 +952,7 @@ class NodeExecutionProgress:
cancelled_ids.append(task_id)
return cancelled_ids
def wait_for_done(self, timeout: float = 5.0):
def wait_for_cancellation(self, timeout: float = 5.0):
"""
Wait for all cancelled tasks to complete cancellation.
@@ -1002,12 +962,9 @@ class NodeExecutionProgress:
start_time = time.time()
while time.time() - start_time < timeout:
while self.pop_output():
pass
if self.is_done():
return
# Check if all tasks are done (either completed or cancelled)
if all(task.done() for task in self.tasks.values()):
return True
time.sleep(0.1) # Small delay to avoid busy waiting
raise TimeoutError(
@@ -1026,7 +983,11 @@ class NodeExecutionProgress:
if self.output[exec_id]:
return False
self.tasks.pop(exec_id)
if task := self.tasks.pop(exec_id):
try:
self.on_done_task(exec_id, task.result())
except Exception as e:
logger.error(f"Task for exec ID {exec_id} failed with error: {str(e)}")
return True
def _next_exec(self) -> str | None:

View File

@@ -5,6 +5,7 @@ from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from typing import Optional
from autogpt_libs.utils.cache import thread_cached
from autogpt_libs.utils.synchronize import AsyncRedisKeyedMutex
from pydantic import SecretStr
@@ -182,15 +183,6 @@ zerobounce_credentials = APIKeyCredentials(
expires_at=None,
)
enrichlayer_credentials = APIKeyCredentials(
id="d9fce73a-6c1d-4e8b-ba2e-12a456789def",
provider="enrichlayer",
api_key=SecretStr(settings.secrets.enrichlayer_api_key),
title="Use Credits for Enrichlayer",
expires_at=None,
)
llama_api_credentials = APIKeyCredentials(
id="d44045af-1c33-4833-9e19-752313214de2",
provider="llama_api",
@@ -199,14 +191,6 @@ llama_api_credentials = APIKeyCredentials(
expires_at=None,
)
v0_credentials = APIKeyCredentials(
id="c4e6d1a0-3b5f-4789-a8e2-9b123456789f",
provider="v0",
api_key=SecretStr(settings.secrets.v0_api_key),
title="Use Credits for v0 by Vercel",
expires_at=None,
)
DEFAULT_CREDENTIALS = [
ollama_credentials,
revid_credentials,
@@ -220,7 +204,6 @@ DEFAULT_CREDENTIALS = [
jina_credentials,
unreal_credentials,
open_router_credentials,
enrichlayer_credentials,
fal_credentials,
exa_credentials,
e2b_credentials,
@@ -231,8 +214,6 @@ DEFAULT_CREDENTIALS = [
smartlead_credentials,
zerobounce_credentials,
google_maps_credentials,
llama_api_credentials,
v0_credentials,
]
@@ -248,15 +229,17 @@ class IntegrationCredentialsStore:
return self._locks
@property
@thread_cached
def db_manager(self):
if prisma.is_connected():
from backend.data import user
return user
else:
from backend.util.clients import get_database_manager_async_client
from backend.executor.database import DatabaseManagerAsyncClient
from backend.util.service import get_service_client
return get_database_manager_async_client()
return get_service_client(DatabaseManagerAsyncClient)
# =============== USER-MANAGED CREDENTIALS =============== #
async def add_creds(self, user_id: str, credentials: Credentials) -> None:
@@ -299,8 +282,6 @@ class IntegrationCredentialsStore:
all_credentials.append(unreal_credentials)
if settings.secrets.open_router_api_key:
all_credentials.append(open_router_credentials)
if settings.secrets.enrichlayer_api_key:
all_credentials.append(enrichlayer_credentials)
if settings.secrets.fal_api_key:
all_credentials.append(fal_credentials)
if settings.secrets.exa_api_key:
@@ -382,6 +363,21 @@ class IntegrationCredentialsStore:
# ============== SYSTEM-MANAGED CREDENTIALS ============== #
async def get_ayrshare_profile_key(self, user_id: str) -> SecretStr | None:
    """Look up the Ayrshare profile key stored for a user.

    The profile key authenticates API requests to Ayrshare's social media
    posting service.
    See https://www.ayrshare.com/docs/apis/profiles/overview for more details.

    Args:
        user_id: The ID of the user whose profile key is requested

    Returns:
        The profile key as a SecretStr if set, None otherwise
    """
    integrations = await self._get_user_integrations(user_id)
    managed_credentials = integrations.managed_credentials
    return managed_credentials.ayrshare_profile_key
async def set_ayrshare_profile_key(self, user_id: str, profile_key: str) -> None:
"""Set the Ayrshare profile key for a user.

View File

@@ -25,7 +25,6 @@ class ProviderName(str, Enum):
GROQ = "groq"
HTTP = "http"
HUBSPOT = "hubspot"
ENRICHLAYER = "enrichlayer"
IDEOGRAM = "ideogram"
JINA = "jina"
LLAMA_API = "llama_api"
@@ -48,7 +47,6 @@ class ProviderName(str, Enum):
TWITTER = "twitter"
TODOIST = "todoist"
UNREAL_SPEECH = "unreal_speech"
V0 = "v0"
ZEROBOUNCE = "zerobounce"
@classmethod

View File

@@ -8,11 +8,10 @@ from pydantic import BaseModel
from backend.data.block import get_block
from backend.data.execution import ExecutionStatus, NodeExecutionResult
from backend.util.clients import (
get_database_manager_client,
get_notification_manager_client,
)
from backend.executor import utils as execution_utils
from backend.notifications.notifications import NotificationManagerClient
from backend.util.metrics import sentry_capture_error
from backend.util.service import get_service_client
from backend.util.settings import Config
logger = logging.getLogger(__name__)
@@ -41,7 +40,7 @@ class BlockErrorMonitor:
def __init__(self, include_top_blocks: int | None = None):
self.config = config
self.notification_client = get_notification_manager_client()
self.notification_client = get_service_client(NotificationManagerClient)
self.include_top_blocks = (
include_top_blocks
if include_top_blocks is not None
@@ -108,7 +107,7 @@ class BlockErrorMonitor:
) -> dict[str, BlockStatsWithSamples]:
"""Get block execution stats using efficient SQL aggregation."""
result = get_database_manager_client().get_block_error_stats(
result = execution_utils.get_db_client().get_block_error_stats(
start_time, end_time
)
@@ -198,7 +197,7 @@ class BlockErrorMonitor:
) -> list[str]:
"""Get error samples for a specific block - just a few recent ones."""
# Only fetch a small number of recent failed executions for this specific block
executions = get_database_manager_client().get_node_executions(
executions = execution_utils.get_db_client().get_node_executions(
block_ids=[block_id],
statuses=[ExecutionStatus.FAILED],
created_time_gte=start_time,

Some files were not shown because too many files have changed in this diff Show More