Compare commits

...

4 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Swifty | 894e3600fb | add other specs | 2025-08-01 14:21:57 +02:00 |
| Swifty | 9de4b09f20 | mv to sub dir | 2025-08-01 13:19:42 +02:00 |
| Swifty | 62e41d409a | websocket server running well now | 2025-08-01 13:17:45 +02:00 |
| Swifty | 9f03e3af47 | added websocket service | 2025-08-01 11:19:29 +02:00 |
19 changed files with 7944 additions and 0 deletions

View File

@@ -0,0 +1,802 @@
# DatabaseManager Technical Specification
## Executive Summary
This document provides a complete technical specification for implementing a drop-in replacement for the AutoGPT Platform's DatabaseManager service. The replacement must maintain 100% API compatibility while preserving all functional behaviors, security requirements, and performance characteristics.
## 1. System Overview
### 1.1 Purpose
The DatabaseManager is a centralized service that provides database access for the AutoGPT Platform's executor system. It encapsulates all database operations behind a service interface, enabling distributed execution while maintaining data consistency and security.
### 1.2 Architecture Pattern
- **Service Type**: HTTP-based microservice using FastAPI
- **Communication**: RPC-style over HTTP with JSON serialization
- **Base Class**: Inherits from `AppService` (backend.util.service)
- **Client Classes**: `DatabaseManagerClient` (sync) and `DatabaseManagerAsyncClient` (async)
- **Port**: Configurable via `config.database_api_port`
### 1.3 Critical Requirements
1. **API Compatibility**: All 40+ exposed methods must maintain exact signatures
2. **Type Safety**: Full type preservation across service boundaries
3. **User Isolation**: All operations must respect user_id boundaries
4. **Transaction Support**: Maintain ACID properties for critical operations
5. **Event Publishing**: Maintain Redis event bus integration for real-time updates
## 2. Service Implementation Requirements
### 2.1 Base Service Class
```python
from backend.util.service import AppService, expose
from backend.util.settings import Config
from backend.data import db
import logging

logger = logging.getLogger(__name__)
config = Config()


class DatabaseManager(AppService):
    """
    REQUIRED: Inherit from AppService to get:
    - Automatic endpoint generation via @expose decorator
    - Built-in health checks at /health
    - Request/response serialization
    - Error handling and logging
    """

    def run_service(self) -> None:
        """REQUIRED: Initialize database connection before starting service"""
        logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
        self.run_and_wait(db.connect())  # CRITICAL: Must connect to database
        super().run_service()  # Start HTTP server

    def cleanup(self):
        """REQUIRED: Clean disconnect on shutdown"""
        super().cleanup()
        logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
        self.run_and_wait(db.disconnect())  # CRITICAL: Must disconnect cleanly

    @classmethod
    def get_port(cls) -> int:
        """REQUIRED: Return configured port"""
        return config.database_api_port
```
### 2.2 Method Exposure Pattern
```python
from typing import Callable, Concatenate, ParamSpec, TypeVar, cast

P = ParamSpec("P")
R = TypeVar("R")

# Defined inside the DatabaseManager class body:
@staticmethod
def _(f: Callable[P, R], name: str | None = None) -> Callable[Concatenate[object, P], R]:
    """
    REQUIRED: Helper to expose methods with proper signatures
    - Preserves function name for endpoint generation
    - Maintains type information
    - Adds 'self' parameter for instance binding
    """
    if name is not None:
        f.__name__ = name
    return cast(Callable[Concatenate[object, P], R], expose(f))
```
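For illustration only, the helper is typically used inside the class body to bind existing data-layer coroutines as exposed endpoints. The module path and function names below are assumptions, not part of this specification:

```python
from backend.data import execution as execution_db  # assumed data-layer module

class DatabaseManager(AppService):
    # (continuation of the class from Section 2.1)
    # Each assignment exposes the wrapped coroutine as an HTTP endpoint whose
    # path is derived from the function name.
    get_graph_execution = _(execution_db.get_graph_execution)
    get_graph_executions = _(execution_db.get_graph_executions)
    update_node_execution_stats = _(execution_db.update_node_execution_stats)
```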
### 2.3 Database Connection Management
**REQUIRED: Use Prisma ORM with these exact configurations:**
```python
import os

from prisma import Prisma

HTTP_TIMEOUT = 120  # seconds (default)
DATABASE_URL = os.environ["DATABASE_URL"]  # connection-string source assumed

prisma = Prisma(
    auto_register=True,
    http={"timeout": HTTP_TIMEOUT},
    datasource={"url": DATABASE_URL},
)

# Connection lifecycle
async def connect():
    await prisma.connect()

async def disconnect():
    await prisma.disconnect()
```
### 2.4 Transaction Support
**REQUIRED: Implement both regular and locked transactions:**
```python
import zlib
from contextlib import asynccontextmanager

@asynccontextmanager
async def transaction(timeout: float | None = None):
    """Regular database transaction"""
    async with prisma.tx(timeout=timeout) as tx:
        yield tx

@asynccontextmanager
async def locked_transaction(key: str, timeout: float | None = None):
    """Transaction with a PostgreSQL advisory lock"""
    lock_key = zlib.crc32(key.encode("utf-8"))
    async with transaction(timeout=timeout) as tx:
        await tx.execute_raw("SELECT pg_advisory_xact_lock($1)", lock_key)
        yield tx
```
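A minimal usage sketch: holding the advisory lock on the graph ID serializes concurrent creations of executions for the same graph. The table and column names below are illustrative assumptions, not the real schema:

```python
async def create_graph_execution_sketch(graph_id: str, user_id: str) -> None:
    # Hypothetical sketch; real inserts go through the Prisma models.
    async with locked_transaction(f"graph:{graph_id}") as tx:
        await tx.execute_raw(
            'INSERT INTO "AgentGraphExecution" ("agentGraphId", "userId", "executionStatus") '
            "VALUES ($1, $2, 'QUEUED')",
            graph_id,
            user_id,
        )
```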
## 3. Complete API Specification
### 3.1 Execution Management APIs
#### get_graph_execution
```python
async def get_graph_execution(
user_id: str,
execution_id: str,
*,
include_node_executions: bool = False
) -> GraphExecution | GraphExecutionWithNodes | None
```
**Behavior**:
- Returns execution only if user_id matches
- Optionally includes all node executions
- Returns None if not found or unauthorized
#### get_graph_executions
```python
async def get_graph_executions(
user_id: str,
graph_id: str | None = None,
*,
limit: int = 50,
graph_version: int | None = None,
cursor: str | None = None,
preset_id: str | None = None
) -> tuple[list[GraphExecution], str | None]
```
**Behavior**:
- Paginated results with cursor
- Filter by graph_id, version, or preset_id
- Returns (executions, next_cursor)
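A caller-side sketch of the cursor pagination, assuming `db` is a `DatabaseManagerAsyncClient` instance:

```python
async def fetch_all_executions(db, user_id: str, graph_id: str) -> list:
    """Hypothetical sketch: follow next_cursor until the result set is exhausted."""
    executions, cursor = [], None
    while True:
        page, cursor = await db.get_graph_executions(user_id, graph_id, cursor=cursor)
        executions.extend(page)
        if cursor is None:
            return executions
```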
#### create_graph_execution
```python
async def create_graph_execution(
graph_id: str,
graph_version: int,
starting_nodes_input: dict[str, dict[str, Any]],
user_id: str,
preset_id: str | None = None
) -> GraphExecutionWithNodes
```
**Behavior**:
- Creates execution with status "QUEUED"
- Initializes all nodes with "PENDING" status
- Publishes creation event to Redis
- Uses locked transaction on graph_id
#### update_graph_execution_start_time
```python
async def update_graph_execution_start_time(
graph_exec_id: str
) -> None
```
**Behavior**:
- Sets start_time to current timestamp
- Only updates if currently NULL
#### update_graph_execution_stats
```python
async def update_graph_execution_stats(
graph_exec_id: str,
status: AgentExecutionStatus | None = None,
stats: dict[str, Any] | None = None
) -> GraphExecution | None
```
**Behavior**:
- Updates status and/or stats atomically
- Sets end_time if status is terminal (COMPLETED/FAILED)
- Publishes update event to Redis
- Returns updated execution
#### get_node_execution
```python
async def get_node_execution(
node_exec_id: str
) -> NodeExecutionResult | None
```
**Behavior**:
- No user_id check (relies on graph execution security)
- Includes all input/output data
#### get_node_executions
```python
async def get_node_executions(
graph_exec_id: str
) -> list[NodeExecutionResult]
```
**Behavior**:
- Returns all node executions for graph
- Ordered by creation time
#### get_latest_node_execution
```python
async def get_latest_node_execution(
graph_exec_id: str,
node_id: str
) -> NodeExecutionResult | None
```
**Behavior**:
- Returns most recent execution of specific node
- Used for retry/rerun scenarios
#### update_node_execution_status
```python
async def update_node_execution_status(
node_exec_id: str,
status: AgentExecutionStatus,
execution_data: dict[str, Any] | None = None,
stats: dict[str, Any] | None = None
) -> NodeExecutionResult
```
**Behavior**:
- Updates status atomically
- Sets end_time for terminal states
- Optionally updates stats/data
- Publishes event to Redis
- Returns updated execution
#### update_node_execution_status_batch
```python
async def update_node_execution_status_batch(
execution_updates: list[NodeExecutionUpdate]
) -> list[NodeExecutionResult]
```
**Behavior**:
- Batch update multiple nodes in single transaction
- Each update can have different status/stats
- Publishes events for all updates
- Returns all updated executions
#### update_node_execution_stats
```python
async def update_node_execution_stats(
node_exec_id: str,
stats: dict[str, Any]
) -> NodeExecutionResult
```
**Behavior**:
- Updates only stats field
- Merges with existing stats
- Does not affect status
#### upsert_execution_input
```python
async def upsert_execution_input(
node_id: str,
graph_exec_id: str,
input_name: str,
input_data: Any,
node_exec_id: str | None = None
) -> tuple[str, BlockInput]
```
**Behavior**:
- Creates or updates input data
- If node_exec_id not provided, creates node execution
- Serializes input_data to JSON
- Returns (node_exec_id, input_object)
#### upsert_execution_output
```python
async def upsert_execution_output(
node_exec_id: str,
output_name: str,
output_data: Any
) -> None
```
**Behavior**:
- Creates or updates output data
- Serializes output_data to JSON
- No return value
#### get_execution_kv_data
```python
async def get_execution_kv_data(
user_id: str,
key: str
) -> Any | None
```
**Behavior**:
- User-scoped key-value storage
- Returns deserialized JSON data
- Returns None if key not found
#### set_execution_kv_data
```python
async def set_execution_kv_data(
user_id: str,
node_exec_id: str,
key: str,
data: Any
) -> Any | None
```
**Behavior**:
- Sets user-scoped key-value data
- Associates with node execution
- Serializes data to JSON
- Returns previous value or None
#### get_block_error_stats
```python
async def get_block_error_stats() -> list[BlockErrorStats]
```
**Behavior**:
- Aggregates error counts by block_id
- Last 7 days of data
- Groups by error type
### 3.2 Graph Management APIs
#### get_node
```python
async def get_node(
node_id: str
) -> AgentNode | None
```
**Behavior**:
- Returns node with block data
- No user_id check (public blocks)
#### get_graph
```python
async def get_graph(
graph_id: str,
version: int | None = None,
user_id: str | None = None,
for_export: bool = False,
include_subgraphs: bool = False
) -> GraphModel | None
```
**Behavior**:
- Returns latest version if version=None
- Checks user_id for private graphs
- for_export=True excludes internal fields
- include_subgraphs=True loads nested graphs
#### get_connected_output_nodes
```python
async def get_connected_output_nodes(
node_id: str,
output_name: str
) -> list[tuple[AgentNode, AgentNodeLink]]
```
**Behavior**:
- Returns downstream nodes connected to output
- Includes link metadata
- Used for execution flow
#### get_graph_metadata
```python
async def get_graph_metadata(
graph_id: str,
user_id: str
) -> GraphMetadata | None
```
**Behavior**:
- Returns graph metadata without full definition
- User must own or have access to graph
### 3.3 Credit System APIs
#### get_credits
```python
async def get_credits(
user_id: str
) -> int
```
**Behavior**:
- Returns current credit balance
- Always non-negative
#### spend_credits
```python
async def spend_credits(
user_id: str,
cost: int,
metadata: UsageTransactionMetadata
) -> int
```
**Behavior**:
- Deducts credits atomically
- Creates transaction record
- Throws InsufficientCredits if balance too low
- Returns new balance
- metadata includes: block_id, node_exec_id, context
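A usage sketch, assuming `db` is an async client instance and `InsufficientCredits` is the exception defined in Section 8.1:

```python
async def charge_for_block(db, user_id: str, cost: int, metadata) -> int | None:
    """Hypothetical sketch: deduct credits and surface an insufficient balance."""
    try:
        return await db.spend_credits(user_id, cost, metadata)
    except InsufficientCredits:
        # Caller decides how to react, e.g. halt the node execution and notify the user.
        return None
```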
### 3.4 User Management APIs
#### get_user_metadata
```python
async def get_user_metadata(
user_id: str
) -> UserMetadata
```
**Behavior**:
- Returns user preferences and settings
- Creates default if not exists
#### update_user_metadata
```python
async def update_user_metadata(
user_id: str,
data: UserMetadataDTO
) -> UserMetadata
```
**Behavior**:
- Partial update of metadata
- Validates against schema
- Returns updated metadata
#### get_user_integrations
```python
async def get_user_integrations(
user_id: str
) -> UserIntegrations
```
**Behavior**:
- Returns OAuth credentials
- Decrypts sensitive data
- Creates empty if not exists
#### update_user_integrations
```python
async def update_user_integrations(
user_id: str,
data: UserIntegrations
) -> None
```
**Behavior**:
- Updates integration credentials
- Encrypts sensitive data
- No return value
### 3.5 User Communication APIs
#### get_active_user_ids_in_timerange
```python
async def get_active_user_ids_in_timerange(
start_time: datetime,
end_time: datetime
) -> list[str]
```
**Behavior**:
- Returns users with graph executions in range
- Used for analytics/notifications
#### get_user_email_by_id
```python
async def get_user_email_by_id(
user_id: str
) -> str | None
```
**Behavior**:
- Returns user's email address
- None if user not found
#### get_user_email_verification
```python
async def get_user_email_verification(
user_id: str
) -> UserEmailVerification
```
**Behavior**:
- Returns email and verification status
- Used for notification filtering
#### get_user_notification_preference
```python
async def get_user_notification_preference(
user_id: str
) -> NotificationPreference
```
**Behavior**:
- Returns notification settings
- Creates default if not exists
### 3.6 Notification APIs
#### create_or_add_to_user_notification_batch
```python
async def create_or_add_to_user_notification_batch(
user_id: str,
notification_type: NotificationType,
notification_data: NotificationEvent
) -> UserNotificationBatchDTO
```
**Behavior**:
- Adds to existing batch or creates new
- Batches by type for efficiency
- Returns updated batch
#### empty_user_notification_batch
```python
async def empty_user_notification_batch(
user_id: str,
notification_type: NotificationType
) -> None
```
**Behavior**:
- Clears all notifications of type
- Used after sending batch
#### get_all_batches_by_type
```python
async def get_all_batches_by_type(
notification_type: NotificationType
) -> list[UserNotificationBatchDTO]
```
**Behavior**:
- Returns all user batches of type
- Used by notification service
#### get_user_notification_batch
```python
async def get_user_notification_batch(
user_id: str,
notification_type: NotificationType
) -> UserNotificationBatchDTO | None
```
**Behavior**:
- Returns user's batch for type
- None if no batch exists
#### get_user_notification_oldest_message_in_batch
```python
async def get_user_notification_oldest_message_in_batch(
user_id: str,
notification_type: NotificationType
) -> NotificationEvent | None
```
**Behavior**:
- Returns oldest notification in batch
- Used for batch timing decisions
## 4. Client Implementation Requirements
### 4.1 Synchronous Client
```python
class DatabaseManagerClient(AppServiceClient):
"""
REQUIRED: Synchronous client that:
- Converts async methods to sync using endpoint_to_sync
- Maintains exact method signatures
- Handles connection pooling
- Implements retry logic
"""
@classmethod
def get_service_type(cls):
return DatabaseManager
# Example method mapping
get_graph_execution = endpoint_to_sync(DatabaseManager.get_graph_execution)
```
### 4.2 Asynchronous Client
```python
class DatabaseManagerAsyncClient(AppServiceClient):
"""
REQUIRED: Async client that:
- Directly references async methods
- No conversion needed
- Shares connection pool
"""
@classmethod
def get_service_type(cls):
return DatabaseManager
# Direct method reference
get_graph_execution = DatabaseManager.get_graph_execution
```
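Both clients are obtained through the platform's service-client factory, the same pattern the NotificationManagerClient uses; the import path below is an assumption:

```python
from backend.util.service import get_service_client  # import path assumed

# Synchronous callers (e.g. the executor) use the sync client:
db = get_service_client(DatabaseManagerClient)
balance = db.get_credits(user_id="user-123")

# Async callers use the async client and await the same methods:
async def get_balance(user_id: str) -> int:
    async_db = get_service_client(DatabaseManagerAsyncClient)
    return await async_db.get_credits(user_id)
```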
## 5. Data Models
### 5.1 Core Enums
```python
class AgentExecutionStatus(str, Enum):
PENDING = "PENDING"
QUEUED = "QUEUED"
RUNNING = "RUNNING"
COMPLETED = "COMPLETED"
FAILED = "FAILED"
CANCELED = "CANCELED"
class NotificationType(str, Enum):
SYSTEM = "SYSTEM"
REVIEW = "REVIEW"
EXECUTION = "EXECUTION"
MARKETING = "MARKETING"
```
### 5.2 Key Data Models
All models must exactly match the Prisma schema definitions. Key models include:
- `GraphExecution`: Execution metadata with stats
- `GraphExecutionWithNodes`: Includes all node executions
- `NodeExecutionResult`: Node execution with I/O data
- `GraphModel`: Complete graph definition
- `UserIntegrations`: OAuth credentials
- `UsageTransactionMetadata`: Credit usage context
- `NotificationEvent`: Individual notification data
## 6. Security Requirements
### 6.1 User Isolation
- **CRITICAL**: All user-scoped operations MUST filter by user_id
- Never expose data across user boundaries
- Use database-level row security where possible
### 6.2 Authentication
- Service assumes authentication handled by API gateway
- user_id parameter is trusted after authentication
- No additional auth checks within service
### 6.3 Data Protection
- Encrypt sensitive integration credentials
- Use HMAC for unsubscribe tokens
- Never log sensitive data
## 7. Performance Requirements
### 7.1 Connection Management
- Maintain persistent database connection
- Use connection pooling (default: 10 connections)
- Implement exponential backoff for retries
### 7.2 Query Optimization
- Use indexes for all WHERE clauses
- Batch operations where possible
- Limit default result sets (50 items)
### 7.3 Event Publishing
- Publish events asynchronously
- Don't block on event delivery
- Use fire-and-forget pattern
## 8. Error Handling
### 8.1 Standard Exceptions
```python
class InsufficientCredits(Exception):
"""Raised when user lacks credits"""
class NotFoundError(Exception):
"""Raised when entity not found"""
class AuthorizationError(Exception):
"""Raised when user lacks access"""
```
### 8.2 Error Response Format
```json
{
"error": "error_type",
"message": "Human readable message",
"details": {} // Optional additional context
}
```
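As a sketch (not the mandated implementation), a FastAPI exception handler can translate the exceptions from Section 8.1 into this response shape:

```python
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()

@app.exception_handler(NotFoundError)
async def not_found_handler(request: Request, exc: NotFoundError) -> JSONResponse:
    # Maps the service exception onto the error format above.
    return JSONResponse(
        status_code=404,
        content={"error": "not_found", "message": str(exc), "details": {}},
    )
```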
## 9. Testing Requirements
### 9.1 Unit Tests
- Test each method in isolation
- Mock database calls
- Verify user_id filtering
### 9.2 Integration Tests
- Test with real database
- Verify transaction boundaries
- Test concurrent operations
### 9.3 Service Tests
- Test HTTP endpoint generation
- Verify serialization/deserialization
- Test error handling
## 10. Implementation Checklist
### Phase 1: Core Service Setup
- [ ] Create DatabaseManager class inheriting from AppService
- [ ] Implement run_service() with database connection
- [ ] Implement cleanup() with proper disconnect
- [ ] Configure port from settings
- [ ] Set up method exposure helper
### Phase 2: Execution APIs (17 methods)
- [ ] get_graph_execution
- [ ] get_graph_executions
- [ ] get_graph_execution_meta
- [ ] create_graph_execution
- [ ] update_graph_execution_start_time
- [ ] update_graph_execution_stats
- [ ] get_node_execution
- [ ] get_node_executions
- [ ] get_latest_node_execution
- [ ] update_node_execution_status
- [ ] update_node_execution_status_batch
- [ ] update_node_execution_stats
- [ ] upsert_execution_input
- [ ] upsert_execution_output
- [ ] get_execution_kv_data
- [ ] set_execution_kv_data
- [ ] get_block_error_stats
### Phase 3: Graph APIs (4 methods)
- [ ] get_node
- [ ] get_graph
- [ ] get_connected_output_nodes
- [ ] get_graph_metadata
### Phase 4: Credit APIs (2 methods)
- [ ] get_credits
- [ ] spend_credits
### Phase 5: User APIs (4 methods)
- [ ] get_user_metadata
- [ ] update_user_metadata
- [ ] get_user_integrations
- [ ] update_user_integrations
### Phase 6: Communication APIs (4 methods)
- [ ] get_active_user_ids_in_timerange
- [ ] get_user_email_by_id
- [ ] get_user_email_verification
- [ ] get_user_notification_preference
### Phase 7: Notification APIs (5 methods)
- [ ] create_or_add_to_user_notification_batch
- [ ] empty_user_notification_batch
- [ ] get_all_batches_by_type
- [ ] get_user_notification_batch
- [ ] get_user_notification_oldest_message_in_batch
### Phase 8: Client Implementation
- [ ] Create DatabaseManagerClient with sync methods
- [ ] Create DatabaseManagerAsyncClient with async methods
- [ ] Test client method generation
- [ ] Verify type preservation
### Phase 9: Integration Testing
- [ ] Test all methods with real database
- [ ] Verify user isolation
- [ ] Test error scenarios
- [ ] Performance testing
- [ ] Event publishing verification
### Phase 10: Deployment Validation
- [ ] Deploy to test environment
- [ ] Run integration test suite
- [ ] Verify backward compatibility
- [ ] Performance benchmarking
- [ ] Production deployment
## 11. Success Criteria
The implementation is successful when:
1. **All 40+ methods** produce identical outputs to the original
2. **Performance** is within 10% of original implementation
3. **All tests** pass without modification
4. **No breaking changes** to any client code
5. **Security boundaries** are maintained
6. **Event publishing** works identically
7. **Error handling** matches original behavior
## 12. Critical Implementation Notes
1. **DO NOT** modify any function signatures
2. **DO NOT** change any return types
3. **DO NOT** add new required parameters
4. **DO NOT** remove any functionality
5. **ALWAYS** maintain user_id isolation
6. **ALWAYS** publish events for state changes
7. **ALWAYS** use transactions for multi-step operations
8. **ALWAYS** handle errors exactly as original
This specification, when implemented correctly, will produce a drop-in replacement for the DatabaseManager that maintains 100% compatibility with the existing system.

View File

@@ -0,0 +1,765 @@
# Notification Service Technical Specification
## Overview
The AutoGPT Platform Notification Service is a RabbitMQ-based asynchronous notification system that handles various types of user notifications including real-time alerts, batched notifications, and scheduled summaries. The service supports email delivery via Postmark and system alerts via Discord.
## Architecture Overview
### Core Components
1. **NotificationManager Service** (`notifications.py`)
- AppService implementation with RabbitMQ integration
- Processes notification queues asynchronously
- Manages batching strategies and delivery timing
- Handles email templating and sending
2. **RabbitMQ Message Broker**
- Multiple queues for different notification strategies
- Dead letter exchange for failed messages
- Topic-based routing for message distribution
3. **Email Sender** (`email.py`)
- Postmark integration for email delivery
- Jinja2 template rendering
- HTML email composition with unsubscribe headers
4. **Database Storage**
- Notification batching tables
- User preference storage
- Email verification tracking
## Service Exposure Mechanism
### AppService Framework
The NotificationManager extends `AppService` which automatically exposes methods decorated with `@expose` as HTTP endpoints:
```python
class NotificationManager(AppService):
@expose
def queue_weekly_summary(self):
# Implementation
@expose
def process_existing_batches(self, notification_types: list[NotificationType]):
# Implementation
@expose
async def discord_system_alert(self, content: str):
# Implementation
```
### Automatic HTTP Endpoint Creation
When the service starts, the AppService base class:
1. Scans for methods with `@expose` decorator
2. Creates FastAPI routes for each exposed method:
- Route path: `/{method_name}`
- HTTP method: POST
- Endpoint handler: Generated via `_create_fastapi_endpoint()`
### Service Client Access
#### NotificationManagerClient
```python
class NotificationManagerClient(AppServiceClient):
@classmethod
def get_service_type(cls):
return NotificationManager
# Direct method references (sync)
process_existing_batches = NotificationManager.process_existing_batches
queue_weekly_summary = NotificationManager.queue_weekly_summary
# Async-to-sync conversion
discord_system_alert = endpoint_to_sync(NotificationManager.discord_system_alert)
```
#### Client Usage Pattern
```python
# Get client instance
client = get_service_client(NotificationManagerClient)
# Call exposed methods via HTTP
client.process_existing_batches([NotificationType.AGENT_RUN])
client.queue_weekly_summary()
client.discord_system_alert("System alert message")
```
### HTTP Communication Details
1. **Service URL**: `http://{host}:{notification_service_port}`
- Default port: 8007
- Host: Configurable via settings
2. **Request Format**:
- Method: POST
- Path: `/{method_name}`
- Body: JSON with method parameters
3. **Client Implementation**:
- Uses `httpx` for HTTP requests
- Automatic retry on connection failures
- Configurable timeout (default from api_call_timeout)
### Direct Function Calls
The service module also provides two functions that can be called directly, without going through the service client:
```python
# Sync version - used by ExecutionManager
def queue_notification(event: NotificationEventModel) -> NotificationResult
# Async version - used by credit system
async def queue_notification_async(event: NotificationEventModel) -> NotificationResult
```
These functions:
- Connect directly to RabbitMQ
- Publish messages to appropriate queues
- Return success/failure status
- Are NOT exposed via HTTP
## Message Queuing Architecture
### RabbitMQ Configuration
#### Exchanges
```python
NOTIFICATION_EXCHANGE = Exchange(name="notifications", type=ExchangeType.TOPIC)
DEAD_LETTER_EXCHANGE = Exchange(name="dead_letter", type=ExchangeType.TOPIC)
```
#### Queues
1. **immediate_notifications**
- Routing Key: `notification.immediate.#`
- Dead Letter: `failed.immediate`
- For: Critical alerts, errors
2. **admin_notifications**
- Routing Key: `notification.admin.#`
- Dead Letter: `failed.admin`
- For: Refund requests, system alerts
3. **summary_notifications**
- Routing Key: `notification.summary.#`
- Dead Letter: `failed.summary`
- For: Daily/weekly summaries
4. **batch_notifications**
- Routing Key: `notification.batch.#`
- Dead Letter: `failed.batch`
- For: Agent runs, batched events
5. **failed_notifications**
- Routing Key: `failed.#`
- For: All failed messages
### Queue Strategies (QueueType enum)
1. **IMMEDIATE**: Send right away (errors, critical notifications)
2. **BATCH**: Batch for configured delay (agent runs)
3. **SUMMARY**: Scheduled digest (daily/weekly summaries)
4. **BACKOFF**: Exponential backoff strategy (defined but not fully implemented)
5. **ADMIN**: Admin-only notifications
## Notification Types
### Enum Values (NotificationType)
```python
AGENT_RUN # Batch strategy, 1 day delay
ZERO_BALANCE # Backoff strategy, 60 min delay
LOW_BALANCE # Immediate strategy
BLOCK_EXECUTION_FAILED # Backoff strategy, 60 min delay
CONTINUOUS_AGENT_ERROR # Backoff strategy, 60 min delay
DAILY_SUMMARY # Summary strategy
WEEKLY_SUMMARY # Summary strategy
MONTHLY_SUMMARY # Summary strategy
REFUND_REQUEST # Admin strategy
REFUND_PROCESSED # Admin strategy
```
## Integration Points
### 1. Scheduler Integration
The scheduler service (`backend.executor.scheduler`) imports monitoring functions that call the NotificationManagerClient:
```python
from backend.monitoring import (
process_existing_batches,
process_weekly_summary,
)
# These are scheduled as cron jobs
```
### 2. Execution Manager Integration
The ExecutionManager directly calls `queue_notification()` for:
- Agent run completions
- Low balance alerts
```python
from backend.notifications.notifications import queue_notification
# Called after graph execution completes
queue_notification(NotificationEventModel(
user_id=graph_exec.user_id,
type=NotificationType.AGENT_RUN,
data=AgentRunData(...)
))
```
### 3. Credit System Integration
The credit system uses `queue_notification_async()` for:
- Refund requests
- Refund processed notifications
```python
from backend.notifications.notifications import queue_notification_async
await queue_notification_async(NotificationEventModel(
user_id=user_id,
type=NotificationType.REFUND_REQUEST,
data=RefundRequestData(...)
))
```
### 4. Monitoring Module Wrappers
The monitoring module provides wrapper functions that are used by the scheduler:
```python
# backend/monitoring/notification_monitor.py
def process_existing_batches(**kwargs):
args = NotificationJobArgs(**kwargs)
get_notification_manager_client().process_existing_batches(
args.notification_types
)
def process_weekly_summary(**kwargs):
get_notification_manager_client().queue_weekly_summary()
```
## Data Models
### Base Event Model
```typescript
interface BaseEventModel {
type: NotificationType;
user_id: string;
created_at: string; // ISO datetime with timezone
}
```
### Notification Event Model
```typescript
interface NotificationEventModel<T> extends BaseEventModel {
data: T;
}
```
### Notification Data Types
#### AgentRunData
```typescript
interface AgentRunData {
agent_name: string;
credits_used: number;
execution_time: number;
node_count: number;
graph_id: string;
outputs: Array<Record<string, any>>;
}
```
#### ZeroBalanceData
```typescript
interface ZeroBalanceData {
last_transaction: number;
last_transaction_time: string; // ISO datetime with timezone
top_up_link: string;
}
```
#### LowBalanceData
```typescript
interface LowBalanceData {
agent_name: string;
current_balance: number; // credits (100 = $1)
billing_page_link: string;
shortfall: number;
}
```
#### BlockExecutionFailedData
```typescript
interface BlockExecutionFailedData {
block_name: string;
block_id: string;
error_message: string;
graph_id: string;
node_id: string;
execution_id: string;
}
```
#### ContinuousAgentErrorData
```typescript
interface ContinuousAgentErrorData {
agent_name: string;
error_message: string;
graph_id: string;
execution_id: string;
start_time: string; // ISO datetime with timezone
error_time: string; // ISO datetime with timezone
attempts: number;
}
```
#### Summary Data Types
```typescript
interface BaseSummaryData {
total_credits_used: number;
total_executions: number;
most_used_agent: string;
total_execution_time: number;
successful_runs: number;
failed_runs: number;
average_execution_time: number;
cost_breakdown: Record<string, number>;
}
interface DailySummaryData extends BaseSummaryData {
date: string; // ISO datetime with timezone
}
interface WeeklySummaryData extends BaseSummaryData {
start_date: string; // ISO datetime with timezone
end_date: string; // ISO datetime with timezone
}
```
#### RefundRequestData
```typescript
interface RefundRequestData {
user_id: string;
user_name: string;
user_email: string;
transaction_id: string;
refund_request_id: string;
reason: string;
amount: number;
balance: number;
}
```
### Summary Parameters
```typescript
interface BaseSummaryParams {
start_date: string; // ISO datetime with timezone
end_date: string; // ISO datetime with timezone
}
interface DailySummaryParams extends BaseSummaryParams {
date: string; // ISO datetime with timezone
}
interface WeeklySummaryParams extends BaseSummaryParams {
start_date: string; // ISO datetime with timezone
end_date: string; // ISO datetime with timezone
}
```
## Database Schema
### NotificationEvent Table
```prisma
model NotificationEvent {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
UserNotificationBatch UserNotificationBatch? @relation
userNotificationBatchId String?
type NotificationType
data Json
@@index([userNotificationBatchId])
}
```
### UserNotificationBatch Table
```prisma
model UserNotificationBatch {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
userId String
User User @relation
type NotificationType
Notifications NotificationEvent[]
@@unique([userId, type])
}
```
## API Methods
### Exposed Service Methods (via HTTP)
#### queue_weekly_summary()
- **HTTP Endpoint**: `POST /queue_weekly_summary`
- **Purpose**: Triggers weekly summary generation for all active users
- **Process**:
1. Runs in background executor
2. Queries users active in last 7 days
3. Queues summary notification for each user
- **Used by**: Scheduler service (via cron)
#### process_existing_batches(notification_types: list[NotificationType])
- **HTTP Endpoint**: `POST /process_existing_batches`
- **Purpose**: Processes aged-out batches for specified notification types
- **Process**:
1. Runs in background executor
2. Retrieves all batches for given types
3. Checks if oldest message exceeds max delay
4. Sends batched email if aged out
5. Clears processed batches
- **Used by**: Scheduler service (via cron)
#### discord_system_alert(content: str)
- **HTTP Endpoint**: `POST /discord_system_alert`
- **Purpose**: Sends system alerts to Discord channel
- **Async**: Yes (converted to sync by client)
- **Used by**: Monitoring services
### Direct Queue Functions (not via HTTP)
#### queue_notification(event: NotificationEventModel) -> NotificationResult
- **Purpose**: Queue a notification (sync version)
- **Used by**: ExecutionManager (same process)
- **Direct RabbitMQ**: Yes
#### queue_notification_async(event: NotificationEventModel) -> NotificationResult
- **Purpose**: Queue a notification (async version)
- **Used by**: Credit system (async context)
- **Direct RabbitMQ**: Yes
## Message Processing Flow
### 1. Message Routing
```python
def get_routing_key(event_type: NotificationType) -> str:
strategy = NotificationTypeOverride(event_type).strategy
if strategy == QueueType.IMMEDIATE:
return f"notification.immediate.{event_type.value}"
elif strategy == QueueType.BATCH:
return f"notification.batch.{event_type.value}"
# ... etc
```
### 2. Queue Processing Methods
#### _process_immediate(message: str) -> bool
1. Parse message to NotificationEventModel
2. Retrieve user email
3. Check user preferences and email verification
4. Send email immediately via EmailSender
5. Return True if successful
#### _process_batch(message: str) -> bool
1. Parse message to NotificationEventModel
2. Add to user's notification batch
3. Check if batch is old enough (based on delay)
4. If aged out:
- Retrieve all batch messages
- Send combined email
- Clear batch
5. Return True if processed or batched
#### _process_summary(message: str) -> bool
1. Parse message to SummaryParamsEventModel
2. Gather summary data (credits, executions, etc.)
- **Note**: Currently returns hardcoded placeholder data
3. Format and send summary email
4. Return True if successful
#### _process_admin_message(message: str) -> bool
1. Parse message
2. Send to configured admin email
3. No user preference checks
4. Return True if successful
## Email Delivery
### EmailSender Class
#### Template Loading
- Base template: `templates/base.html.jinja2`
- Notification templates: `templates/{notification_type}.html.jinja2`
- Subject templates from NotificationTypeOverride
- **Note**: Templates use `.html.jinja2` extension, not just `.html`
#### Email Composition
```python
def send_templated(
notification: NotificationType,
user_email: str,
data: NotificationEventModel | list[NotificationEventModel],
user_unsub_link: str | None = None
)
```
#### Postmark Integration
- API Token: `settings.secrets.postmark_server_api_token`
- Sender Email: `settings.config.postmark_sender_email`
- Headers:
- `List-Unsubscribe-Post: List-Unsubscribe=One-Click`
- `List-Unsubscribe: <{unsubscribe_link}>`
## User Preferences and Permissions
### Email Verification Check
```python
validated_email = get_db().get_user_email_verification(user_id)
```
### Notification Preferences
```python
preferences = get_db().get_user_notification_preference(user_id).preferences
# Returns dict[NotificationType, bool]
```
### Preference Fields in User Model
- `notifyOnAgentRun`
- `notifyOnZeroBalance`
- `notifyOnLowBalance`
- `notifyOnBlockExecutionFailed`
- `notifyOnContinuousAgentError`
- `notifyOnDailySummary`
- `notifyOnWeeklySummary`
- `notifyOnMonthlySummary`
### Unsubscribe Link Generation
```python
def generate_unsubscribe_link(user_id: str) -> str:
# HMAC-SHA256 signed token
# Format: base64(user_id:signature_hex)
# URL: {platform_base_url}/api/email/unsubscribe?token={token}
```
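A minimal sketch of the scheme described above; the exact secret handling, base64 variant, and URL layout must match the existing implementation:

```python
import base64
import hashlib
import hmac

def generate_unsubscribe_link_sketch(user_id: str, secret_key: str, base_url: str) -> str:
    """Hypothetical sketch: HMAC-SHA256 token of the form base64(user_id:signature_hex)."""
    signature = hmac.new(secret_key.encode(), user_id.encode(), hashlib.sha256).hexdigest()
    token = base64.urlsafe_b64encode(f"{user_id}:{signature}".encode()).decode()
    return f"{base_url}/api/email/unsubscribe?token={token}"
```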
## Batching Logic
### Batch Delays (get_batch_delay)
**Note**: The delay configuration exists for multiple notification types, but only notifications with `QueueType.BATCH` strategy actually use batching. Others use different strategies:
- `AGENT_RUN`: 1 day (Strategy: BATCH - actually uses batching)
- `ZERO_BALANCE`: 60 minutes configured (Strategy: BACKOFF - not batched)
- `LOW_BALANCE`: 60 minutes configured (Strategy: IMMEDIATE - sent immediately)
- `BLOCK_EXECUTION_FAILED`: 60 minutes configured (Strategy: BACKOFF - not batched)
- `CONTINUOUS_AGENT_ERROR`: 60 minutes configured (Strategy: BACKOFF - not batched)
### Batch Processing
1. Messages added to UserNotificationBatch
2. Oldest message timestamp tracked
3. When `oldest_timestamp + delay < now()`:
- Batch is processed
- All messages sent in single email
- Batch cleared
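A minimal sketch of the age-out test, assuming timestamps are timezone-aware UTC:

```python
from datetime import datetime, timedelta, timezone

def batch_is_aged_out(oldest_created_at: datetime, delay: timedelta) -> bool:
    """Hypothetical helper: True once the oldest batched message has exceeded its delay."""
    return oldest_created_at + delay < datetime.now(timezone.utc)
```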
## Service Lifecycle
### Startup
1. Initialize FastAPI app with exposed endpoints
2. Start HTTP server on port 8007
3. Initialize RabbitMQ connection
4. Create/verify exchanges and queues
5. Set up queue consumers
6. Start processing loop
### Main Loop
```python
while self.running:
await self._run_queue(immediate_queue, self._process_immediate, ...)
await self._run_queue(admin_queue, self._process_admin_message, ...)
await self._run_queue(batch_queue, self._process_batch, ...)
await self._run_queue(summary_queue, self._process_summary, ...)
await asyncio.sleep(0.1)
```
### Shutdown
1. Set `running = False`
2. Disconnect RabbitMQ
3. Cleanup resources
## Configuration
### Environment Variables
```python
# Service Configuration
notification_service_port: int = 8007
# Email Configuration
postmark_sender_email: str = "invalid@invalid.com"
refund_notification_email: str = "refund@agpt.co"
# Security
unsubscribe_secret_key: str = ""
# Secrets
postmark_server_api_token: str = ""
postmark_webhook_token: str = ""
discord_bot_token: str = ""
# Platform URLs
platform_base_url: str
frontend_base_url: str
```
## Error Handling
### Message Processing Errors
- Failed messages sent to dead letter queue
- Validation errors logged but don't crash service
- Connection errors trigger retry with `@continuous_retry()`
### RabbitMQ ACK/NACK Protocol
- Success: `message.ack()`
- Failure: `message.reject(requeue=False)`
- Timeout/Queue empty: Continue loop
### HTTP Endpoint Errors
- Wrapped in RemoteCallError for client
- Automatic retry available via client configuration
- Connection failures tracked and logged
## System Integrations
### DatabaseManagerClient
- User email retrieval
- Email verification status
- Notification preferences
- Batch management
- Active user queries
### Discord Integration
- Uses SendDiscordMessageBlock
- Configured via discord_bot_token
- For system alerts only
## Implementation Checklist
1. **Core Service**
- [ ] AppService implementation with @expose decorators
- [ ] FastAPI endpoint generation
- [ ] RabbitMQ connection management
- [ ] Queue consumer setup
- [ ] Message routing logic
2. **Service Client**
- [ ] NotificationManagerClient implementation
- [ ] HTTP client configuration
- [ ] Method mapping to service endpoints
- [ ] Async-to-sync conversions
3. **Message Processing**
- [ ] Parse and validate all notification types
- [ ] Implement all queue strategies
- [ ] Batch management with delays
- [ ] Summary data gathering
4. **Email Delivery**
- [ ] Postmark integration
- [ ] Template loading and rendering
- [ ] Unsubscribe header support
- [ ] HTML email composition
5. **User Management**
- [ ] Preference checking
- [ ] Email verification
- [ ] Unsubscribe link generation
- [ ] Daily limit tracking
6. **Batching System**
- [ ] Database batch operations
- [ ] Age-out checking
- [ ] Batch clearing after send
- [ ] Oldest message tracking
7. **Error Handling**
- [ ] Dead letter queue routing
- [ ] Message rejection on failure
- [ ] Continuous retry wrapper
- [ ] Validation error logging
8. **Scheduled Operations**
- [ ] Weekly summary generation
- [ ] Batch processing triggers
- [ ] Background executor usage
## Security Considerations
1. **Service-to-Service Communication**:
- HTTP endpoints only accessible internally
- No authentication on service endpoints (internal network only)
- Service discovery via host/port configuration
2. **User Security**:
- Email verification required for all user notifications
- Unsubscribe tokens HMAC-signed
- User preferences enforced
3. **Admin Notifications**:
- Separate queue, no user preference checks
- Fixed admin email configuration
## Testing Considerations
1. **Unit Tests**
- Message parsing and validation
- Routing key generation
- Batch delay calculations
- Template rendering
2. **Integration Tests**
- HTTP endpoint accessibility
- Service client method calls
- RabbitMQ message flow
- Database batch operations
- Email sending (mock Postmark)
3. **Load Tests**
- High volume message processing
- Concurrent HTTP requests
- Batch accumulation limits
- Memory usage under load
## Implementation Status Notes
1. **Backoff Strategy**: While `QueueType.BACKOFF` is defined and used by several notification types (ZERO_BALANCE, BLOCK_EXECUTION_FAILED, CONTINUOUS_AGENT_ERROR), the actual exponential backoff processing logic is not implemented. These messages are routed to the immediate queue.
2. **Summary Data**: The `_gather_summary_data()` method currently returns hardcoded placeholder values rather than querying actual execution data from the database.
3. **Batch Processing**: Only `AGENT_RUN` notifications actually use batch processing. Other notification types with configured delays use different strategies (IMMEDIATE or BACKOFF).
## Future Enhancements
1. **Additional Channels**
- SMS notifications (not implemented)
- Webhook notifications (not implemented)
- In-app notifications
2. **Advanced Batching**
- Dynamic batch sizes
- Priority-based processing
- Custom delay configurations
3. **Analytics**
- Delivery tracking
- Open/click rates
- Notification effectiveness metrics
4. **Service Improvements**
- Authentication for HTTP endpoints
- Rate limiting per user
- Circuit breaker patterns
- Implement actual backoff processing for BACKOFF strategy
- Implement real summary data gathering

View File

@@ -0,0 +1,474 @@
# AutoGPT Platform Scheduler Technical Specification
## Executive Summary
This document provides a comprehensive technical specification for the AutoGPT Platform Scheduler service. The scheduler is responsible for managing scheduled graph executions, system monitoring tasks, and periodic maintenance operations. This specification is designed to enable a complete reimplementation that maintains 100% compatibility with the existing system.
## Table of Contents
1. [System Architecture](#system-architecture)
2. [Service Implementation](#service-implementation)
3. [Data Models](#data-models)
4. [API Endpoints](#api-endpoints)
5. [Database Schema](#database-schema)
6. [External Dependencies](#external-dependencies)
7. [Authentication & Authorization](#authentication--authorization)
8. [Process Management](#process-management)
9. [Error Handling](#error-handling)
10. [Configuration](#configuration)
11. [Testing Strategy](#testing-strategy)
## System Architecture
### Overview
The scheduler operates as an independent microservice within the AutoGPT platform, implementing the `AppService` base class pattern. It runs on a dedicated port (default: 8003) and exposes HTTP/JSON-RPC endpoints for communication with other services.
### Core Components
1. **Scheduler Service** (`backend/executor/scheduler.py:156`)
- Extends `AppService` base class
- Manages APScheduler instance with multiple jobstores
- Handles lifecycle management and graceful shutdown
2. **Scheduler Client** (`backend/executor/scheduler.py:354`)
- Extends `AppServiceClient` base class
- Provides async/sync method wrappers for RPC calls
- Implements automatic retry and connection pooling
3. **Entry Points**
- Main executable: `backend/scheduler.py`
- Service launcher: `backend/app.py`
## Service Implementation
### Base Service Pattern
```python
class Scheduler(AppService):
scheduler: BlockingScheduler
def __init__(self, register_system_tasks: bool = True):
self.register_system_tasks = register_system_tasks
@classmethod
def get_port(cls) -> int:
return config.execution_scheduler_port # Default: 8003
@classmethod
def db_pool_size(cls) -> int:
return config.scheduler_db_pool_size # Default: 3
def run_service(self):
# Initialize scheduler with jobstores
# Register system tasks if enabled
# Start scheduler blocking loop
def cleanup(self):
# Graceful shutdown of scheduler
# Wait=False for immediate termination
```
### Jobstore Configuration
The scheduler uses three distinct jobstores:
1. **EXECUTION** (`Jobstores.EXECUTION.value`)
- Type: SQLAlchemyJobStore
- Table: `apscheduler_jobs`
- Purpose: Graph execution schedules
- Persistence: Required
2. **BATCHED_NOTIFICATIONS** (`Jobstores.BATCHED_NOTIFICATIONS.value`)
- Type: SQLAlchemyJobStore
- Table: `apscheduler_jobs_batched_notifications`
- Purpose: Batched notification processing
- Persistence: Required
3. **WEEKLY_NOTIFICATIONS** (`Jobstores.WEEKLY_NOTIFICATIONS.value`)
- Type: MemoryJobStore
- Purpose: Weekly summary notifications
- Persistence: Not required
### System Tasks
When `register_system_tasks=True`, the following monitoring tasks are registered:
1. **Weekly Summary Processing**
- Job ID: `process_weekly_summary`
- Schedule: `0 * * * *` (hourly)
- Function: `monitoring.process_weekly_summary`
- Jobstore: WEEKLY_NOTIFICATIONS
2. **Late Execution Monitoring**
- Job ID: `report_late_executions`
- Schedule: Interval (config.execution_late_notification_threshold_secs)
- Function: `monitoring.report_late_executions`
- Jobstore: EXECUTION
3. **Block Error Rate Monitoring**
- Job ID: `report_block_error_rates`
- Schedule: Interval (config.block_error_rate_check_interval_secs)
- Function: `monitoring.report_block_error_rates`
- Jobstore: EXECUTION
4. **Cloud Storage Cleanup**
- Job ID: `cleanup_expired_files`
- Schedule: Interval (config.cloud_storage_cleanup_interval_hours * 3600)
- Function: `cleanup_expired_files`
- Jobstore: EXECUTION
## Data Models
### GraphExecutionJobArgs
```python
class GraphExecutionJobArgs(BaseModel):
user_id: str
graph_id: str
graph_version: int
cron: str
input_data: BlockInput
input_credentials: dict[str, CredentialsMetaInput] = Field(default_factory=dict)
```
### GraphExecutionJobInfo
```python
class GraphExecutionJobInfo(GraphExecutionJobArgs):
id: str
name: str
next_run_time: str
@staticmethod
def from_db(job_args: GraphExecutionJobArgs, job_obj: JobObj) -> "GraphExecutionJobInfo":
return GraphExecutionJobInfo(
id=job_obj.id,
name=job_obj.name,
next_run_time=job_obj.next_run_time.isoformat(),
**job_args.model_dump(),
)
```
### NotificationJobArgs
```python
class NotificationJobArgs(BaseModel):
notification_types: list[NotificationType]
cron: str
```
### CredentialsMetaInput
```python
class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
id: str
title: Optional[str] = None
provider: CP
type: CT
```
## API Endpoints
All endpoints are exposed via the `@expose` decorator and follow an HTTP POST, JSON-RPC-style pattern.
### 1. Add Graph Execution Schedule
**Endpoint**: `/add_graph_execution_schedule`
**Request Body**:
```json
{
"user_id": "string",
"graph_id": "string",
"graph_version": "integer",
"cron": "string (crontab format)",
"input_data": {},
"input_credentials": {},
"name": "string (optional)"
}
```
**Response**: `GraphExecutionJobInfo`
**Behavior**:
- Creates APScheduler job with CronTrigger
- Uses job kwargs to store GraphExecutionJobArgs
- Sets `replace_existing=True` to allow updates
- Returns job info with generated ID and next run time
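A sketch of this scheduling step; the helper name is hypothetical, while `execute_graph`, `Jobstores`, and `GraphExecutionJobInfo.from_db` are referenced elsewhere in this document:

```python
from apscheduler.triggers.cron import CronTrigger

def add_graph_execution_schedule_sketch(scheduler, job_args: GraphExecutionJobArgs, name: str | None = None):
    """Hypothetical sketch of the behavior described above."""
    job = scheduler.add_job(
        execute_graph,                              # job function (see Appendix)
        trigger=CronTrigger.from_crontab(job_args.cron),
        kwargs=job_args.model_dump(),               # all job arguments live in kwargs
        name=name,
        replace_existing=True,
        jobstore=Jobstores.EXECUTION.value,
    )
    return GraphExecutionJobInfo.from_db(job_args, job)
```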
### 2. Delete Graph Execution Schedule
**Endpoint**: `/delete_graph_execution_schedule`
**Request Body**:
```json
{
"schedule_id": "string",
"user_id": "string"
}
```
**Response**: `GraphExecutionJobInfo`
**Behavior**:
- Validates schedule exists in EXECUTION jobstore
- Verifies user_id matches job's user_id
- Removes job from scheduler
- Returns deleted job info
**Errors**:
- `NotFoundError`: If job doesn't exist
- `NotAuthorizedError`: If user_id doesn't match
### 3. Get Graph Execution Schedules
**Endpoint**: `/get_graph_execution_schedules`
**Request Body**:
```json
{
"graph_id": "string (optional)",
"user_id": "string (optional)"
}
```
**Response**: `list[GraphExecutionJobInfo]`
**Behavior**:
- Retrieves all jobs from EXECUTION jobstore
- Filters by graph_id and/or user_id if provided
- Validates job kwargs as GraphExecutionJobArgs
- Skips invalid jobs (ValidationError)
- Only returns jobs with next_run_time set
### 4. System Task Endpoints
- `/execute_process_existing_batches` - Trigger batch processing
- `/execute_process_weekly_summary` - Trigger weekly summary
- `/execute_report_late_executions` - Trigger late execution report
- `/execute_report_block_error_rates` - Trigger error rate report
- `/execute_cleanup_expired_files` - Trigger file cleanup
### 5. Health Check
**Endpoints**: `/health_check`, `/health_check_async`
**Methods**: POST, GET
**Response**: "OK"
## Database Schema
### APScheduler Tables
The scheduler relies on APScheduler's SQLAlchemy jobstore schema:
1. **apscheduler_jobs**
- id: VARCHAR (PRIMARY KEY)
- next_run_time: FLOAT
- job_state: BLOB/BYTEA (pickled job data)
2. **apscheduler_jobs_batched_notifications**
- Same schema as above
- Separate table for notification jobs
### Database Configuration
- URL extraction from `DIRECT_URL` environment variable
- Schema extraction from URL query parameter
- Connection pooling: `pool_size=db_pool_size()`, `max_overflow=0`
- Metadata schema binding for multi-schema support
## External Dependencies
### Required Services
1. **PostgreSQL Database**
- Connection via `DIRECT_URL` environment variable
- Schema support via URL parameter
- APScheduler job persistence
2. **ExecutionManager** (via execution_utils)
- Function: `add_graph_execution`
- Called by: `execute_graph` job function
- Purpose: Create graph execution entries
3. **NotificationManager** (via monitoring module)
- Functions: `process_existing_batches`, `queue_weekly_summary`
- Purpose: Notification processing
4. **Cloud Storage** (via util.cloud_storage)
- Function: `cleanup_expired_files_async`
- Purpose: File expiration management
### Python Dependencies
```
apscheduler>=3.10.0
sqlalchemy
pydantic>=2.0
httpx
uvicorn
fastapi
python-dotenv
tenacity
```
## Authentication & Authorization
### Service-Level Authentication
- No authentication required between internal services
- Services communicate via trusted internal network
- Host/port configuration via environment variables
### User-Level Authorization
- Authorization check in `delete_graph_execution_schedule`:
- Validates `user_id` matches job's `user_id`
- Raises `NotAuthorizedError` on mismatch
- No authorization for read operations (security consideration)
## Process Management
### Startup Sequence
1. Load environment variables via `dotenv.load_dotenv()`
2. Extract database URL and schema
3. Initialize BlockingScheduler with configured jobstores
4. Register system tasks (if enabled)
5. Add job execution listener
6. Start scheduler (blocking)
### Shutdown Sequence
1. Receive SIGTERM/SIGINT signal
2. Call `cleanup()` method
3. Shutdown scheduler with `wait=False`
4. Terminate process
### Multi-Process Architecture
- Runs as independent process via `AppProcess`
- Started by `run_processes()` in app.py
- Can run in foreground or background mode
- Automatic signal handling for graceful shutdown
## Error Handling
### Job Execution Errors
- Listener on `EVENT_JOB_ERROR` logs failures
- Errors in job functions are caught and logged
- Jobs continue to run on schedule despite failures
### RPC Communication Errors
- Automatic retry via `@conn_retry` decorator
- Configurable retry count and timeout
- Connection pooling with self-healing
### Database Connection Errors
- APScheduler handles reconnection automatically
- Pool exhaustion prevented by `max_overflow=0`
- Connection errors logged but don't crash service
## Configuration
### Environment Variables
- `DIRECT_URL`: PostgreSQL connection string (required)
- `{SERVICE_NAME}_HOST`: Override service host
- Standard logging configuration
### Config Settings (via Config class)
```python
execution_scheduler_port: int = 8003
scheduler_db_pool_size: int = 3
execution_late_notification_threshold_secs: int
block_error_rate_check_interval_secs: int
cloud_storage_cleanup_interval_hours: int
pyro_host: str = "localhost"
pyro_client_comm_timeout: float = 15
pyro_client_comm_retry: int = 3
rpc_client_call_timeout: int = 300
```
## Testing Strategy
### Unit Tests
1. Mock APScheduler for job management tests
2. Mock database connections
3. Test each RPC endpoint independently
4. Verify job serialization/deserialization
### Integration Tests
1. Test with real PostgreSQL instance
2. Verify job persistence across restarts
3. Test concurrent job execution
4. Validate cron expression parsing
### Critical Test Cases
1. **Job Persistence**: Jobs survive scheduler restart
2. **User Isolation**: Users can only delete their own jobs
3. **Concurrent Access**: Multiple clients can add/remove jobs
4. **Error Recovery**: Service recovers from database outages
5. **Resource Cleanup**: No memory/connection leaks
## Implementation Notes
### Key Design Decisions
1. **BlockingScheduler vs AsyncIOScheduler**: Uses BlockingScheduler for simplicity and compatibility with multiprocessing architecture
2. **Job Storage**: All job arguments stored in kwargs, not in job name/id
3. **Separate Jobstores**: Isolation between execution and notification jobs
4. **No Authentication**: Relies on network isolation for security
### Migration Considerations
1. APScheduler job format must be preserved exactly
2. Database schema cannot change without migration
3. RPC protocol must maintain compatibility
4. Environment variables must match existing deployment
### Performance Considerations
1. Database pool size limited to prevent exhaustion
2. No job result storage (fire-and-forget pattern)
3. Minimal logging in hot paths
4. Connection reuse via pooling
## Appendix: Critical Implementation Details
### Event Loop Management
```python
@thread_cached
def get_event_loop():
return asyncio.new_event_loop()
def execute_graph(**kwargs):
get_event_loop().run_until_complete(_execute_graph(**kwargs))
```
### Job Function Execution Context
- Jobs run in scheduler's process space
- Each job gets fresh event loop
- No shared state between job executions
- Exceptions logged but don't affect scheduler
### Cron Expression Format
- Uses standard crontab format via `CronTrigger.from_crontab()`
- Supports: minute hour day month day_of_week
- Special strings: @yearly, @monthly, @weekly, @daily, @hourly
This specification provides all necessary details to reimplement the scheduler service while maintaining 100% compatibility with the existing system. Any deviation from these specifications may result in system incompatibility.

View File

@@ -0,0 +1,85 @@
name: CI
on:
push:
branches: [ main, master ]
pull_request:
branches: [ main, master ]
env:
CARGO_TERM_COLOR: always
RUSTFLAGS: "-D warnings"
jobs:
test:
name: Test
runs-on: ubuntu-latest
services:
redis:
image: redis:7
options: >-
--health-cmd "redis-cli ping"
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 6379:6379
steps:
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
- uses: Swatinem/rust-cache@v2
- name: Run tests
run: cargo test
env:
REDIS_URL: redis://localhost:6379
clippy:
name: Clippy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
with:
components: clippy
- uses: Swatinem/rust-cache@v2
- name: Run clippy
run: |
cargo clippy -- \
-D warnings \
-D clippy::unwrap_used \
-D clippy::panic \
-D clippy::unimplemented \
-D clippy::todo
fmt:
name: Format
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt
- name: Check formatting
run: cargo fmt -- --check
bench:
name: Benchmarks
runs-on: ubuntu-latest
services:
redis:
image: redis:7
options: >-
--health-cmd "redis-cli ping"
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 6379:6379
steps:
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
- uses: Swatinem/rust-cache@v2
- name: Build benchmarks
run: cargo bench --no-run
env:
REDIS_URL: redis://localhost:6379

File diff suppressed because it is too large

View File

@@ -0,0 +1,60 @@
[package]
name = "websocket"
authors = ["AutoGPT Team"]
description = "WebSocket server for AutoGPT Platform"
version = "0.1.0"
edition = "2021"
[lib]
name = "websocket"
path = "src/lib.rs"
[[bin]]
name = "websocket"
path = "src/main.rs"
[dependencies]
axum = { version = "0.7.5", features = ["ws"] }
jsonwebtoken = "9.3.0"
redis = { version = "0.25.4", features = ["aio", "tokio-comp"] }
serde = { version = "1.0.204", features = ["derive"] }
serde_json = "1.0.120"
tokio = { version = "1.38.1", features = ["rt-multi-thread", "macros", "net", "sync", "time", "io-util"] }
tower-http = { version = "0.5.2", features = ["cors"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
futures = "0.3"
dotenvy = "0.15"
clap = { version = "4.5.4", features = ["derive"] }
toml = "0.8"
[dev-dependencies]
# Load testing and profiling
tokio-console = "0.1"
criterion = { version = "0.5", features = ["async_tokio"] }
pprof = { version = "0.13", features = ["flamegraph", "criterion"] }
# Dependencies for benchmarks
tokio-tungstenite = "0.24"
futures-util = "0.3"
chrono = "0.4"
[[bench]]
name = "websocket_bench"
harness = false
[[example]]
name = "ws_client_example"
required-features = []
[profile.release]
opt-level = 3 # Maximum optimization
lto = true # Enable link-time optimization
codegen-units = 1 # Reduce parallel code generation units to increase optimization
panic = "abort" # Remove panic unwinding to reduce binary size
strip = true # Strip symbols from binary
[profile.bench]
opt-level = 3 # Maximum optimization
lto = true # Enable link-time optimization
codegen-units = 1 # Reduce parallel code generation units to increase optimization
debug = true # Keep debug symbols for profiling

View File

@@ -0,0 +1,412 @@
# WebSocket API Technical Specification
## Overview
This document provides a complete technical specification for the AutoGPT Platform WebSocket API (`ws_api.py`). The WebSocket API provides real-time updates for graph and node execution events, enabling clients to monitor workflow execution progress.
## Architecture Overview
### Core Components
1. **WebSocket Server** (`ws_api.py`)
- FastAPI application with WebSocket endpoint
- Handles client connections and message routing
- Authenticates clients via JWT tokens
- Manages subscriptions to execution events
2. **Connection Manager** (`conn_manager.py`)
- Maintains active WebSocket connections
- Manages channel subscriptions
- Routes execution events to subscribed clients
- Handles connection lifecycle
3. **Event Broadcasting System**
- Redis Pub/Sub based event bus
- Asynchronous event broadcaster
- Execution event propagation from backend services
## API Endpoint
### WebSocket Endpoint
- **URL**: `/ws`
- **Protocol**: WebSocket (ws:// or wss://)
- **Query Parameters**:
- `token` (required when auth enabled): JWT authentication token
## Authentication
### JWT Token Authentication
- **When Required**: When `settings.config.enable_auth` is `True`
- **Token Location**: Query parameter `?token=<JWT_TOKEN>`
- **Token Validation**:
```python
payload = parse_jwt_token(token)
user_id = payload.get("sub")
```
- **JWT Requirements**:
- Algorithm: Configured via `settings.JWT_ALGORITHM`
- Secret Key: Configured via `settings.JWT_SECRET_KEY`
- Audience: Must be "authenticated"
- Claims: Must contain `sub` (user ID)
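A minimal sketch of how these requirements map onto the `jsonwebtoken` crate used by the Rust implementation in this changeset (the `Claims` struct here is reduced to just the `sub` field; `exp` is checked by the library by default):
```rust
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use serde::Deserialize;

#[derive(Deserialize)]
struct Claims {
    sub: String, // user ID
}

/// Returns the authenticated user ID, or an error if the token is invalid or expired.
fn authenticate(token: &str, secret: &str) -> Result<String, jsonwebtoken::errors::Error> {
    let mut validation = Validation::new(Algorithm::HS256); // algorithm comes from settings
    validation.set_audience(&["authenticated"]);            // audience must be "authenticated"
    let key = DecodingKey::from_secret(secret.as_bytes());
    let data = decode::<Claims>(token, &key, &validation)?;
    Ok(data.claims.sub)
}
```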
### Authentication Failures
- **4001**: Missing authentication token
- **4002**: Invalid token (missing user ID)
- **4003**: Invalid token (parsing error or expired)
### No-Auth Mode
- When `settings.config.enable_auth` is `False`
- Uses `DEFAULT_USER_ID` from `backend.data.user`
## Message Protocol
### Message Format
All messages use JSON format with the following structure:
```typescript
interface WSMessage {
  method: WSMethod;
  data?: Record<string, any> | any[] | string;
  success?: boolean;
  channel?: string;
  error?: string;
}
```
### Message Methods (WSMethod enum)
1. **Client-to-Server Methods**:
- `SUBSCRIBE_GRAPH_EXEC`: Subscribe to specific graph execution
- `SUBSCRIBE_GRAPH_EXECS`: Subscribe to all executions of a graph
- `UNSUBSCRIBE`: Unsubscribe from a channel
- `HEARTBEAT`: Keep-alive ping
2. **Server-to-Client Methods**:
- `GRAPH_EXECUTION_EVENT`: Graph execution status update
- `NODE_EXECUTION_EVENT`: Node execution status update
- `ERROR`: Error message
- `HEARTBEAT`: Keep-alive pong
## Subscription Models
### Subscribe to Specific Graph Execution
```typescript
interface WSSubscribeGraphExecutionRequest {
  graph_exec_id: string;
}
```
**Channel Key Format**: `{user_id}|graph_exec#{graph_exec_id}`
### Subscribe to All Graph Executions
```typescript
interface WSSubscribeGraphExecutionsRequest {
  graph_id: string;
}
```
**Channel Key Format**: `{user_id}|graph#{graph_id}|executions`
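Both channel keys are simple string concatenations; a sketch of the derivation, mirroring the `format!` calls in `handlers.rs` (the helper names are illustrative):
```rust
/// Channel for one specific graph execution.
fn graph_exec_channel(user_id: &str, graph_exec_id: &str) -> String {
    format!("{user_id}|graph_exec#{graph_exec_id}")
}

/// Channel for all executions of a graph.
fn graph_executions_channel(user_id: &str, graph_id: &str) -> String {
    format!("{user_id}|graph#{graph_id}|executions")
}

// graph_exec_channel("user-456", "exec-123")        == "user-456|graph_exec#exec-123"
// graph_executions_channel("user-456", "graph-789") == "user-456|graph#graph-789|executions"
```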
## Event Models
### Graph Execution Event
```typescript
interface GraphExecutionEvent {
  event_type: "graph_execution_update";
  id: string;           // graph_exec_id
  user_id: string;
  graph_id: string;
  graph_version: number;
  preset_id?: string;
  status: ExecutionStatus;
  started_at: string;   // ISO datetime
  ended_at: string;     // ISO datetime
  inputs: Record<string, any>;
  outputs: Record<string, any>;
  stats?: {
    cost: number;       // cents
    duration: number;   // seconds
    duration_cpu_only: number;
    node_exec_time: number;
    node_exec_time_cpu_only: number;
    node_exec_count: number;
    node_error_count: number;
    error?: string;
  };
}
```
### Node Execution Event
```typescript
interface NodeExecutionEvent {
  event_type: "node_execution_update";
  user_id: string;
  graph_id: string;
  graph_version: number;
  graph_exec_id: string;
  node_exec_id: string;
  node_id: string;
  block_id: string;
  status: ExecutionStatus;
  input_data: Record<string, any>;
  output_data: Record<string, any>;
  add_time: string;     // ISO datetime
  queue_time?: string;  // ISO datetime
  start_time?: string;  // ISO datetime
  end_time?: string;    // ISO datetime
}
```
### Execution Status Enum
```typescript
enum ExecutionStatus {
  INCOMPLETE = "INCOMPLETE",
  QUEUED = "QUEUED",
  RUNNING = "RUNNING",
  COMPLETED = "COMPLETED",
  FAILED = "FAILED",
  TERMINATED = "TERMINATED"
}
```
## Message Flow Examples
### 1. Subscribe to Graph Execution
```json
// Client → Server
{
  "method": "subscribe_graph_execution",
  "data": {
    "graph_exec_id": "exec-123"
  }
}

// Server → Client (Success)
{
  "method": "subscribe_graph_execution",
  "success": true,
  "channel": "user-456|graph_exec#exec-123"
}
```
### 2. Receive Execution Updates
```json
// Server → Client (Graph Update)
{
  "method": "graph_execution_event",
  "channel": "user-456|graph_exec#exec-123",
  "data": {
    "event_type": "graph_execution_update",
    "id": "exec-123",
    "user_id": "user-456",
    "graph_id": "graph-789",
    "status": "RUNNING",
    // ... other fields
  }
}

// Server → Client (Node Update)
{
  "method": "node_execution_event",
  "channel": "user-456|graph_exec#exec-123",
  "data": {
    "event_type": "node_execution_update",
    "node_exec_id": "node-exec-111",
    "status": "COMPLETED",
    // ... other fields
  }
}
```
### 3. Heartbeat
```json
// Client → Server
{
  "method": "heartbeat",
  "data": "ping"
}

// Server → Client
{
  "method": "heartbeat",
  "data": "pong",
  "success": true
}
```
### 4. Error Handling
```json
// Server → Client (Invalid Message)
{
  "method": "error",
  "success": false,
  "error": "Invalid message format. Review the schema and retry"
}
```
## Event Broadcasting Architecture
### Redis Pub/Sub Integration
1. **Event Bus Name**: Configured via `config.execution_event_bus_name`
2. **Channel Pattern**: `{event_bus_name}/{channel_key}`
3. **Event Flow**:
- Execution services publish events to Redis
- Event broadcaster listens to Redis pattern `*`
- Events are routed to WebSocket connections based on subscriptions
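For illustration only — the real publisher is the backend execution service — a sketch of publishing a wrapped event that the broadcaster in `connection_manager.rs` would pick up, assuming the channel layout `{event_bus_name}/{user_id}/{graph_id}/{graph_exec_id}` that it parses:
```rust
use redis::AsyncCommands;
use serde_json::json;

async fn publish_example(redis_url: &str) -> redis::RedisResult<()> {
    let client = redis::Client::open(redis_url)?;
    let mut conn = client.get_multiplexed_async_connection().await?;

    // Channel layout: {event_bus_name}/{user_id}/{graph_id}/{graph_exec_id}
    let channel = "execution_event/user-456/graph-789/exec-123";

    // Events are wrapped in a {"payload": ...} envelope (see RedisEventWrapper in models.rs).
    let event = json!({
        "payload": {
            "event_type": "graph_execution_update",
            "id": "exec-123",
            "graph_id": "graph-789",
            "graph_version": 1,
            "user_id": "user-456",
            "status": "RUNNING",
            "started_at": "2024-01-01T00:00:00Z",
            "inputs": {},
            "outputs": {}
        }
    });

    let _: () = conn.publish(channel, event.to_string()).await?;
    Ok(())
}
```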
### Event Broadcaster
- Runs as continuous async task using `@continuous_retry()` decorator
- Listens to all execution events via `AsyncRedisExecutionEventBus`
- Calls `ConnectionManager.send_execution_update()` for each event
## Connection Lifecycle
### Connection Establishment
1. Client connects to `/ws` endpoint
2. Authentication performed (JWT validation)
3. WebSocket accepted via `manager.connect_socket()`
4. Connection added to active connections set
### Message Processing Loop
1. Receive text message from client
2. Parse and validate as `WSMessage`
3. Route to appropriate handler based on `method`
4. Send response or error back to client
### Connection Termination
1. `WebSocketDisconnect` exception caught
2. `manager.disconnect_socket()` called
3. Connection removed from active connections
4. All subscriptions for that connection removed
## Error Handling
### Validation Errors
- **Invalid Message Format**: Returns error with method "error"
- **Invalid Message Data**: Returns error with specific validation message
- **Unknown Message Type**: Returns error indicating unsupported method
### Connection Errors
- WebSocket disconnections handled gracefully
- Failed event parsing logged but doesn't crash connection
- Handler exceptions logged and connection continues
## Configuration
### Environment Variables
```python
# WebSocket Server Configuration
websocket_server_host: str = "0.0.0.0"
websocket_server_port: int = 8001
# Authentication
enable_auth: bool = True
# CORS
backend_cors_allow_origins: List[str] = []
# Redis Event Bus
execution_event_bus_name: str = "autogpt:execution_event_bus"
# Message Size Limits
max_message_size_limit: int = 512000 # 512KB
```
### Security Headers
- CORS middleware applied with configured origins
- Credentials allowed for authenticated requests
- All methods and headers allowed (configurable)
## Deployment Requirements
### Dependencies
1. **FastAPI**: Web framework with WebSocket support
2. **Redis**: For pub/sub event broadcasting
3. **JWT Libraries**: For token validation
4. **Prisma**: Database ORM (for future graph access validation)
### Process Management
- Implements `AppProcess` interface for service lifecycle
- Runs via `uvicorn` ASGI server
- Graceful shutdown handling in `cleanup()` method
### Concurrent Connections
- No hard limit on WebSocket connections
- Memory usage scales with active connections
- Each connection maintains subscription set
## Implementation Checklist
To implement a compatible WebSocket API:
1. **Authentication**
- [ ] JWT token validation from query parameters
- [ ] Support for no-auth mode with default user ID
- [ ] Proper error codes for auth failures
2. **Message Handling**
- [ ] Parse and validate WSMessage format
- [ ] Implement all client-to-server methods
- [ ] Support all server-to-client event types
- [ ] Proper error responses for invalid messages
3. **Subscription Management**
- [ ] Channel key generation matching exact format
- [ ] Support for both execution and graph-level subscriptions
- [ ] Unsubscribe functionality
- [ ] Clean up subscriptions on disconnect
4. **Event Broadcasting**
- [ ] Listen to Redis pub/sub for execution events
- [ ] Route events to correct subscribed connections
- [ ] Handle both graph and node execution events
- [ ] Maintain event order and completeness
5. **Connection Management**
- [ ] Track active WebSocket connections
- [ ] Handle graceful disconnections
- [ ] Implement heartbeat/keepalive
- [ ] Memory-efficient subscription storage
6. **Configuration**
- [ ] Support all environment variables
- [ ] CORS configuration for allowed origins
- [ ] Configurable host/port binding
- [ ] Redis connection configuration
7. **Error Handling**
- [ ] Graceful handling of malformed messages
- [ ] Logging of errors without dropping connections
- [ ] Specific error messages for debugging
- [ ] Recovery from Redis connection issues
## Testing Considerations
1. **Unit Tests**
- Message parsing and validation
- Channel key generation
- Subscription management logic
2. **Integration Tests**
- Full WebSocket connection flow
- Event broadcasting from Redis
- Multi-client subscription scenarios
- Authentication success/failure cases
3. **Load Tests**
- Many concurrent connections
- High-frequency event broadcasting
- Memory usage under load
- Connection/disconnection cycles
## Security Considerations
1. **Authentication**: JWT tokens transmitted via query parameters (consider upgrading to headers)
2. **Authorization**: Currently no graph-level access validation (commented out in code)
3. **Rate Limiting**: No rate limiting implemented
4. **Message Size**: Limited by `max_message_size_limit` configuration
5. **Input Validation**: All inputs validated via Pydantic models
## Future Enhancements (Currently Commented Out)
1. **Graph Access Validation**: Verify user has read access to subscribed graphs
2. **Message Compression**: For large execution payloads
3. **Batch Updates**: Aggregate multiple events in single message
4. **Selective Field Subscription**: Subscribe to specific fields only

View File

@@ -0,0 +1,93 @@
# WebSocket Server Benchmarks
This directory contains performance benchmarks for the AutoGPT WebSocket server.
## Prerequisites
1. Redis must be running locally, or set the `REDIS_URL` environment variable to point at an existing instance:
```bash
docker run -d -p 6379:6379 redis:latest
```
2. Build the project in release mode:
```bash
cargo build --release
```
## Running Benchmarks
Run all benchmarks:
```bash
cargo bench
```
Run specific benchmark group:
```bash
cargo bench connection_establishment
cargo bench subscriptions
cargo bench message_throughput
cargo bench concurrent_connections
cargo bench message_parsing
cargo bench redis_event_processing
```
## Benchmark Categories
### Connection Establishment
Tests the performance of establishing WebSocket connections with different authentication scenarios:
- No authentication
- Valid JWT authentication
- Invalid JWT authentication (connection rejection)
### Subscriptions
Measures the performance of subscription operations:
- Subscribing to graph execution events
- Unsubscribing from channels
### Message Throughput
Tests how many messages the server can process per second with varying message counts (10, 100, 1000).
### Concurrent Connections
Benchmarks the server's ability to handle multiple simultaneous connections (100, 500, and 1000 clients).
### Message Parsing
Tests JSON parsing performance with different message sizes (100 B to 10 KB).
### Redis Event Processing
Benchmarks the parsing of execution events received from Redis.
## Profiling
To generate flamegraphs for CPU profiling:
1. Install flamegraph tools:
```bash
cargo install flamegraph
```
2. Run benchmarks with profiling:
```bash
cargo bench --bench websocket_bench -- --profile-time=10
```
## Interpreting Results
- **Throughput**: Higher is better (operations/second or elements/second)
- **Time**: Lower is better (nanoseconds per operation)
- **Error margins**: Look for stable results with low standard deviation
## Optimizing Performance
Based on benchmark results, consider:
1. **Connection pooling** for Redis connections
2. **Message batching** for high-throughput scenarios
3. **Async task tuning** for concurrent connection handling
4. **JSON parsing optimization** using simd-json or other fast parsers
5. **Memory allocation** optimization using arena allocators
## Notes
- Benchmarks create actual WebSocket servers on random ports
- Each benchmark iteration properly cleans up resources
- Results may vary based on system resources and Redis performance

View File

@@ -0,0 +1,406 @@
#![allow(clippy::unwrap_used)] // Benchmarks can panic on setup errors
use axum::{routing::get, Router};
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use futures_util::{SinkExt, StreamExt};
use serde_json::json;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::runtime::Runtime;
use tokio_tungstenite::{connect_async, tungstenite::Message};
// Import the actual websocket server components
use websocket::{models, ws_handler, AppState, Config, ConnectionManager, Stats};
// Helper to create a test server
async fn create_test_server(enable_auth: bool) -> (String, tokio::task::JoinHandle<()>) {
// Set environment variables for test config
std::env::set_var("WEBSOCKET_SERVER_HOST", "127.0.0.1");
std::env::set_var("WEBSOCKET_SERVER_PORT", "0");
std::env::set_var("ENABLE_AUTH", enable_auth.to_string());
std::env::set_var("SUPABASE_JWT_SECRET", "test_secret");
std::env::set_var("DEFAULT_USER_ID", "test_user");
if std::env::var("REDIS_URL").is_err() {
std::env::set_var("REDIS_URL", "redis://localhost:6379");
}
let mut config = Config::load(None);
config.port = 0; // Force OS to assign port
let redis_client =
redis::Client::open(config.redis_url.clone()).expect("Failed to connect to Redis");
let stats = Arc::new(Stats::default());
let mgr = Arc::new(ConnectionManager::new(
redis_client,
config.execution_event_bus_name.clone(),
stats.clone(),
));
// Start broadcaster
let mgr_clone = mgr.clone();
tokio::spawn(async move {
mgr_clone.run_broadcaster().await;
});
let state = AppState {
mgr,
config: Arc::new(config),
stats,
};
let app = Router::new()
.route("/ws", get(ws_handler))
.layer(axum::Extension(state));
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let server_url = format!("ws://{addr}");
let server_handle = tokio::spawn(async move {
axum::serve(listener, app.into_make_service())
.await
.unwrap();
});
// Give server time to start
tokio::time::sleep(Duration::from_millis(100)).await;
(server_url, server_handle)
}
// Helper to create a valid JWT token
fn create_jwt_token(user_id: &str) -> String {
use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
use serde::Serialize;
#[derive(Serialize)]
struct Claims {
sub: String,
aud: Vec<String>,
exp: usize,
}
let claims = Claims {
sub: user_id.to_string(),
aud: vec!["authenticated".to_string()],
exp: (chrono::Utc::now() + chrono::Duration::hours(1)).timestamp() as usize,
};
encode(
&Header::new(Algorithm::HS256),
&claims,
&EncodingKey::from_secret(b"test_secret"),
)
.unwrap()
}
// Benchmark connection establishment
fn benchmark_connection_establishment(c: &mut Criterion) {
let rt = Runtime::new().unwrap();
let mut group = c.benchmark_group("connection_establishment");
group.measurement_time(Duration::from_secs(30));
// Test without auth
group.bench_function("no_auth", |b| {
b.to_async(&rt).iter_with_large_drop(|| async {
let (server_url, server_handle) = create_test_server(false).await;
let url = format!("{server_url}/ws");
let (ws_stream, _) = connect_async(&url).await.unwrap();
drop(ws_stream);
server_handle.abort();
});
});
// Test with valid auth
group.bench_function("valid_auth", |b| {
b.to_async(&rt).iter_with_large_drop(|| async {
let (server_url, server_handle) = create_test_server(true).await;
let token = create_jwt_token("test_user");
let url = format!("{server_url}/ws?token={token}");
let (ws_stream, _) = connect_async(&url).await.unwrap();
drop(ws_stream);
server_handle.abort();
});
});
// Test with invalid auth
group.bench_function("invalid_auth", |b| {
b.to_async(&rt).iter_with_large_drop(|| async {
let (server_url, server_handle) = create_test_server(true).await;
let url = format!("{server_url}/ws?token=invalid");
let result = connect_async(&url).await;
assert!(
result.is_err() || {
if let Ok((mut ws_stream, _)) = result {
// Should receive close frame
matches!(ws_stream.next().await, Some(Ok(Message::Close(_))))
} else {
false
}
}
);
server_handle.abort();
});
});
group.finish();
}
// Benchmark subscription operations
fn benchmark_subscriptions(c: &mut Criterion) {
let rt = Runtime::new().unwrap();
let mut group = c.benchmark_group("subscriptions");
group.measurement_time(Duration::from_secs(20));
group.bench_function("subscribe_graph_execution", |b| {
b.to_async(&rt).iter_with_large_drop(|| async {
let (server_url, server_handle) = create_test_server(false).await;
let url = format!("{server_url}/ws");
let (mut ws_stream, _) = connect_async(&url).await.unwrap();
let msg = json!({
"method": "subscribe_graph_execution",
"data": {
"graph_exec_id": "test_exec_123"
}
});
ws_stream
.send(Message::Text(msg.to_string()))
.await
.unwrap();
// Wait for response
if let Some(Ok(Message::Text(response))) = ws_stream.next().await {
let resp: serde_json::Value = serde_json::from_str(&response).unwrap();
assert_eq!(resp["success"], true);
}
server_handle.abort();
});
});
group.bench_function("unsubscribe", |b| {
b.to_async(&rt).iter_with_large_drop(|| async {
let (server_url, server_handle) = create_test_server(false).await;
let url = format!("{server_url}/ws");
let (mut ws_stream, _) = connect_async(&url).await.unwrap();
// First subscribe
let msg = json!({
"method": "subscribe_graph_execution",
"data": {
"graph_exec_id": "test_exec_123"
}
});
ws_stream
.send(Message::Text(msg.to_string()))
.await
.unwrap();
ws_stream.next().await; // Consume response
let msg = json!({
"method": "unsubscribe",
"data": {
"channel": "test_user|graph_exec#test_exec_123"
}
});
ws_stream
.send(Message::Text(msg.to_string()))
.await
.unwrap();
// Wait for response
if let Some(Ok(Message::Text(response))) = ws_stream.next().await {
let resp: serde_json::Value = serde_json::from_str(&response).unwrap();
assert_eq!(resp["success"], true);
}
server_handle.abort();
});
});
group.finish();
}
// Benchmark message throughput
fn benchmark_message_throughput(c: &mut Criterion) {
let rt = Runtime::new().unwrap();
let mut group = c.benchmark_group("message_throughput");
group.measurement_time(Duration::from_secs(30));
for msg_count in [10, 100, 1000].iter() {
group.throughput(Throughput::Elements(*msg_count as u64));
group.bench_with_input(
BenchmarkId::from_parameter(msg_count),
msg_count,
|b, &msg_count| {
b.to_async(&rt).iter_with_large_drop(|| async {
let (server_url, server_handle) = create_test_server(false).await;
let url = format!("{server_url}/ws");
let (mut ws_stream, _) = connect_async(&url).await.unwrap();
// Send multiple heartbeat messages
for _ in 0..msg_count {
let msg = json!({
"method": "heartbeat",
"data": "ping"
});
ws_stream
.send(Message::Text(msg.to_string()))
.await
.unwrap();
}
// Receive all responses
for _ in 0..msg_count {
ws_stream.next().await;
}
server_handle.abort();
});
},
);
}
group.finish();
}
// Benchmark concurrent connections
fn benchmark_concurrent_connections(c: &mut Criterion) {
let rt = Runtime::new().unwrap();
let mut group = c.benchmark_group("concurrent_connections");
group.measurement_time(Duration::from_secs(60));
group.sample_size(10);
for num_clients in [100, 500, 1000].iter() {
group.throughput(Throughput::Elements(*num_clients as u64));
group.bench_with_input(
BenchmarkId::from_parameter(num_clients),
num_clients,
|b, &num_clients| {
b.to_async(&rt).iter_with_large_drop(|| async {
let (server_url, server_handle) = create_test_server(false).await;
let url = format!("{server_url}/ws");
// Create multiple concurrent connections
let mut handles = vec![];
for i in 0..num_clients {
let url = url.clone();
let handle = tokio::spawn(async move {
let (mut ws_stream, _) = connect_async(&url).await.unwrap();
// Subscribe to a unique channel
let msg = json!({
"method": "subscribe_graph_execution",
"data": {
"graph_exec_id": format!("exec_{}", i)
}
});
ws_stream
.send(Message::Text(msg.to_string()))
.await
.unwrap();
ws_stream.next().await; // Wait for response
// Send a heartbeat
let msg = json!({
"method": "heartbeat",
"data": "ping"
});
ws_stream
.send(Message::Text(msg.to_string()))
.await
.unwrap();
ws_stream.next().await; // Wait for response
ws_stream
});
handles.push(handle);
}
// Wait for all connections to complete
for handle in handles {
let _ = handle.await;
}
server_handle.abort();
});
},
);
}
group.finish();
}
// Benchmark message parsing
fn benchmark_message_parsing(c: &mut Criterion) {
let mut group = c.benchmark_group("message_parsing");
// Test different message sizes
for msg_size in [100, 1000, 10000].iter() {
group.throughput(Throughput::Bytes(*msg_size as u64));
group.bench_with_input(
BenchmarkId::new("parse_json", msg_size),
msg_size,
|b, &msg_size| {
let data_str = "x".repeat(msg_size);
let json_msg = json!({
"method": "subscribe_graph_execution",
"data": {
"graph_exec_id": data_str
}
});
let json_str = json_msg.to_string();
b.iter(|| {
let _: models::WSMessage = serde_json::from_str(&json_str).unwrap();
});
},
);
}
group.finish();
}
// Benchmark Redis event processing
fn benchmark_redis_event_processing(c: &mut Criterion) {
let mut group = c.benchmark_group("redis_event_processing");
group.bench_function("parse_execution_event", |b| {
let event = json!({
"payload": {
"event_type": "graph_execution_update",
"id": "exec_123",
"graph_id": "graph_456",
"graph_version": 1,
"user_id": "user_789",
"status": "RUNNING",
"started_at": "2024-01-01T00:00:00Z",
"inputs": {"test": "data"},
"outputs": {}
}
});
let event_str = event.to_string();
b.iter(|| {
let _: models::RedisEventWrapper = serde_json::from_str(&event_str).unwrap();
});
});
group.finish();
}
criterion_group!(
benches,
benchmark_connection_establishment,
benchmark_subscriptions,
benchmark_message_throughput,
benchmark_concurrent_connections,
benchmark_message_parsing,
benchmark_redis_event_processing
);
criterion_main!(benches);

View File

@@ -0,0 +1,10 @@
# Clippy configuration for robust error handling
# Set the maximum cognitive complexity allowed
cognitive-complexity-threshold = 30
# Disallow dbg! macros left in test code
allow-dbg-in-tests = false
# Enforce documentation
missing-docs-in-crate-items = true

View File

@@ -0,0 +1,23 @@
# WebSocket API Configuration
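# Loaded from ./config.toml by default; an alternate path can be passed with -c/--config.
# Environment variables with matching names override any value set here.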
# Server settings
host = "0.0.0.0"
port = 8001
# Authentication
enable_auth = true
jwt_secret = "your-super-secret-jwt-token-with-at-least-32-characters-long"
jwt_algorithm = "HS256"
default_user_id = "default"
# Redis configuration
redis_url = "redis://:password@localhost:6379/"
# Event bus
execution_event_bus_name = "execution_event"
# Message size limit (in bytes)
max_message_size_limit = 512000
# CORS allowed origins
backend_cors_allow_origins = ["http://localhost:3000", "https://559f69c159ef.ngrok.app"]

View File

@@ -0,0 +1,75 @@
use futures_util::{SinkExt, StreamExt};
use serde_json::json;
use tokio_tungstenite::{connect_async, tungstenite::Message};
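// Minimal example client: connect, subscribe to a graph execution, send a heartbeat,
// unsubscribe, then close. Run against a locally running server with:
//   cargo run --example ws_client_example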
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let url = "ws://localhost:8001/ws";
println!("Connecting to {url}");
let (mut ws_stream, _) = connect_async(url).await?;
println!("Connected!");
// Subscribe to a graph execution
let subscribe_msg = json!({
"method": "subscribe_graph_execution",
"data": {
"graph_exec_id": "test_exec_123"
}
});
println!("Sending subscription request...");
ws_stream
.send(Message::Text(subscribe_msg.to_string()))
.await?;
// Wait for response
if let Some(msg) = ws_stream.next().await {
if let Message::Text(text) = msg? {
println!("Received: {text}");
}
}
// Send heartbeat
let heartbeat_msg = json!({
"method": "heartbeat",
"data": "ping"
});
println!("Sending heartbeat...");
ws_stream
.send(Message::Text(heartbeat_msg.to_string()))
.await?;
// Wait for pong
if let Some(msg) = ws_stream.next().await {
if let Message::Text(text) = msg? {
println!("Received: {text}");
}
}
// Unsubscribe
let unsubscribe_msg = json!({
"method": "unsubscribe",
"data": {
"channel": "default|graph_exec#test_exec_123"
}
});
println!("Sending unsubscribe request...");
ws_stream
.send(Message::Text(unsubscribe_msg.to_string()))
.await?;
// Wait for response
if let Some(msg) = ws_stream.next().await {
if let Message::Text(text) = msg? {
println!("Received: {text}");
}
}
println!("Closing connection...");
ws_stream.close(None).await?;
Ok(())
}

View File

@@ -0,0 +1,99 @@
use jsonwebtoken::Algorithm;
use serde::Deserialize;
use std::env;
use std::fs;
use std::path::Path;
use std::str::FromStr;
#[derive(Clone, Debug, Deserialize)]
pub struct Config {
pub host: String,
pub port: u16,
pub enable_auth: bool,
pub jwt_secret: String,
pub jwt_algorithm: Algorithm,
pub execution_event_bus_name: String,
pub redis_url: String,
pub default_user_id: String,
pub max_message_size_limit: usize,
pub backend_cors_allow_origins: Vec<String>,
}
impl Config {
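/// Load configuration: if the TOML file (default: ./config.toml) exists and parses, it is
/// used as the base; otherwise environment variables / built-in defaults are used. In both
/// cases, any environment variable that is set overrides the corresponding field.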
pub fn load(config_path: Option<&Path>) -> Self {
let path = config_path.unwrap_or(Path::new("config.toml"));
let toml_result = fs::read_to_string(path)
.ok()
.and_then(|s| toml::from_str::<Config>(&s).ok());
let mut config = match toml_result {
Some(config) => config,
None => Config {
host: env::var("WEBSOCKET_SERVER_HOST").unwrap_or_else(|_| "0.0.0.0".to_string()),
port: env::var("WEBSOCKET_SERVER_PORT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(8001),
enable_auth: env::var("ENABLE_AUTH")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(true),
jwt_secret: env::var("SUPABASE_JWT_SECRET")
.unwrap_or_else(|_| "dummy_secret_for_no_auth".to_string()),
jwt_algorithm: Algorithm::HS256,
execution_event_bus_name: env::var("EXECUTION_EVENT_BUS_NAME")
.unwrap_or_else(|_| "execution_event".to_string()),
redis_url: env::var("REDIS_URL")
.unwrap_or_else(|_| "redis://localhost/".to_string()),
default_user_id: "default".to_string(),
max_message_size_limit: env::var("MAX_MESSAGE_SIZE_LIMIT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(512000),
backend_cors_allow_origins: env::var("BACKEND_CORS_ALLOW_ORIGINS")
.unwrap_or_default()
.split(',')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect(),
},
};
if let Ok(v) = env::var("WEBSOCKET_SERVER_HOST") {
config.host = v;
}
if let Ok(v) = env::var("WEBSOCKET_SERVER_PORT") {
config.port = v.parse().unwrap_or(8001);
}
if let Ok(v) = env::var("ENABLE_AUTH") {
config.enable_auth = v.parse().unwrap_or(true);
}
if let Ok(v) = env::var("SUPABASE_JWT_SECRET") {
config.jwt_secret = v;
}
if let Ok(v) = env::var("JWT_ALGORITHM") {
config.jwt_algorithm = Algorithm::from_str(&v).unwrap_or(Algorithm::HS256);
}
if let Ok(v) = env::var("EXECUTION_EVENT_BUS_NAME") {
config.execution_event_bus_name = v;
}
if let Ok(v) = env::var("REDIS_URL") {
config.redis_url = v;
}
if let Ok(v) = env::var("DEFAULT_USER_ID") {
config.default_user_id = v;
}
if let Ok(v) = env::var("MAX_MESSAGE_SIZE_LIMIT") {
config.max_message_size_limit = v.parse().unwrap_or(512000);
}
if let Ok(v) = env::var("BACKEND_CORS_ALLOW_ORIGINS") {
config.backend_cors_allow_origins = v
.split(',')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect();
}
config
}
}

View File

@@ -0,0 +1,277 @@
use futures::StreamExt;
use redis::Client as RedisClient;
use std::collections::{HashMap, HashSet};
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use tracing::{debug, error, info, warn};
use crate::models::{ExecutionEvent, RedisEventWrapper, WSMessage};
use crate::stats::Stats;
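/// Tracks connected WebSocket clients and their channel subscriptions, and fans out
/// execution events received from Redis pub/sub to every client subscribed to the
/// corresponding channel key.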
pub struct ConnectionManager {
pub subscribers: RwLock<HashMap<String, HashSet<u64>>>,
pub clients: RwLock<HashMap<u64, (String, mpsc::Sender<String>)>>,
pub client_channels: RwLock<HashMap<u64, HashSet<String>>>,
pub next_id: AtomicU64,
pub redis_client: RedisClient,
pub bus_name: String,
pub stats: Arc<Stats>,
}
impl ConnectionManager {
pub fn new(redis_client: RedisClient, bus_name: String, stats: Arc<Stats>) -> Self {
Self {
subscribers: RwLock::new(HashMap::new()),
clients: RwLock::new(HashMap::new()),
client_channels: RwLock::new(HashMap::new()),
next_id: AtomicU64::new(0),
redis_client,
bus_name,
stats,
}
}
pub async fn run_broadcaster(self: Arc<Self>) {
info!("🚀 Starting Redis event broadcaster");
loop {
match self.run_broadcaster_inner().await {
Ok(_) => {
warn!("⚠️ Event broadcaster stopped unexpectedly, restarting in 5 seconds");
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
}
Err(e) => {
error!("❌ Event broadcaster error: {}, restarting in 5 seconds", e);
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
}
}
}
}
async fn run_broadcaster_inner(
self: &Arc<Self>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let mut pubsub = self.redis_client.get_async_pubsub().await?;
pubsub.psubscribe("*").await?;
info!(
"📡 Listening to all Redis events, filtering for bus: {}",
self.bus_name
);
let mut pubsub_stream = pubsub.on_message();
loop {
let msg = pubsub_stream.next().await;
match msg {
Some(msg) => {
let channel: String = msg.get_channel_name().to_string();
debug!("📨 Received message on Redis channel: {}", channel);
self.stats
.redis_messages_received
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let payload: String = match msg.get_payload() {
Ok(p) => p,
Err(e) => {
warn!("⚠️ Failed to get payload from Redis message: {}", e);
self.stats
.errors_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
continue;
}
};
// Parse the channel format: execution_event/{user_id}/{graph_id}/{graph_exec_id}
let parts: Vec<&str> = channel.split('/').collect();
// Check if this is an execution event channel
if parts.len() != 4 || parts[0] != self.bus_name {
debug!(
"🚫 Ignoring non-execution event channel: {} (parts: {:?}, bus_name: {})",
channel, parts, self.bus_name
);
self.stats
.redis_messages_ignored
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
continue;
}
let user_id = parts[1];
let graph_id = parts[2];
let graph_exec_id = parts[3];
debug!(
"📥 Received event - user: {}, graph: {}, exec: {}",
user_id, graph_id, graph_exec_id
);
// Parse the wrapped event
let wrapped_event = match RedisEventWrapper::parse(&payload) {
Ok(e) => e,
Err(e) => {
warn!("⚠️ Failed to parse event JSON: {}, payload: {}", e, payload);
self.stats
.errors_json_parse
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
self.stats
.errors_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
continue;
}
};
let event = wrapped_event.payload;
debug!("📦 Event received: {:?}", event);
let (method, event_json) = match &event {
ExecutionEvent::GraphExecutionUpdate(graph_event) => {
self.stats
.graph_execution_events
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
self.stats
.events_received_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
(
"graph_execution_event",
match serde_json::to_value(graph_event) {
Ok(v) => v,
Err(e) => {
error!("❌ Failed to serialize graph event: {}", e);
continue;
}
},
)
}
ExecutionEvent::NodeExecutionUpdate(node_event) => {
self.stats
.node_execution_events
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
self.stats
.events_received_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
(
"node_execution_event",
match serde_json::to_value(node_event) {
Ok(v) => v,
Err(e) => {
error!("❌ Failed to serialize node event: {}", e);
continue;
}
},
)
}
};
// Create the channel keys in the format expected by WebSocket clients
let mut channels_to_notify = Vec::new();
// For both event types, notify the specific execution channel
let exec_channel = format!("{user_id}|graph_exec#{graph_exec_id}");
channels_to_notify.push(exec_channel.clone());
// For graph execution events, also notify the graph executions channel
if matches!(&event, ExecutionEvent::GraphExecutionUpdate(_)) {
let graph_channel = format!("{user_id}|graph#{graph_id}|executions");
channels_to_notify.push(graph_channel);
}
debug!(
"📢 Broadcasting {} event to channels: {:?}",
method, channels_to_notify
);
let subs = self.subscribers.read().await;
// Log current subscriber state
debug!("📊 Current subscribers count: {}", subs.len());
for channel_key in channels_to_notify {
let ws_msg = WSMessage {
method: method.to_string(),
channel: Some(channel_key.clone()),
data: Some(event_json.clone()),
..Default::default()
};
let json_msg = match serde_json::to_string(&ws_msg) {
Ok(j) => {
debug!("📤 Sending WebSocket message: {}", j);
j
}
Err(e) => {
error!("❌ Failed to serialize WebSocket message: {}", e);
continue;
}
};
if let Some(client_ids) = subs.get(&channel_key) {
let clients = self.clients.read().await;
let client_count = client_ids.len();
debug!(
"📣 Broadcasting to {} clients on channel: {}",
client_count, channel_key
);
for &cid in client_ids {
if let Some((user_id, tx)) = clients.get(&cid) {
match tx.try_send(json_msg.clone()) {
Ok(_) => {
debug!(
"✅ Message sent immediately to client {} (user: {})",
cid, user_id
);
self.stats
.messages_sent_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
Err(mpsc::error::TrySendError::Full(_)) => {
// Channel is full, try with a small timeout
let tx_clone = tx.clone();
let msg_clone = json_msg.clone();
let stats_clone = self.stats.clone();
tokio::spawn(async move {
match tokio::time::timeout(
std::time::Duration::from_millis(100),
tx_clone.send(msg_clone),
)
.await {
Ok(Ok(_)) => {
stats_clone
.messages_sent_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
_ => {
stats_clone
.messages_failed_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}
});
warn!("⚠️ Channel full for client {} (user: {}), sending async", cid, user_id);
}
Err(mpsc::error::TrySendError::Closed(_)) => {
warn!(
"⚠️ Channel closed for client {} (user: {})",
cid, user_id
);
self.stats
.messages_failed_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}
} else {
warn!("⚠️ Client {} not found in clients map", cid);
}
}
} else {
info!("📭 No subscribers for channel: {}", channel_key);
}
}
}
None => {
return Err("❌ Redis pubsub stream ended".into());
}
}
}
}
}

View File

@@ -0,0 +1,442 @@
use axum::extract::ws::{CloseFrame, Message, WebSocket};
use axum::{
extract::{Query, WebSocketUpgrade},
http::HeaderMap,
response::IntoResponse,
Extension,
};
use jsonwebtoken::{decode, DecodingKey, Validation};
use serde_json::{json, Value};
use std::collections::HashMap;
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};
use crate::connection_manager::ConnectionManager;
use crate::models::{Claims, WSMessage};
use crate::AppState;
// Helper function to safely serialize messages
fn serialize_message(msg: &WSMessage) -> String {
serde_json::to_string(msg).unwrap_or_else(|e| {
error!("❌ Failed to serialize WebSocket message: {}", e);
json!({"method": "error", "success": false, "error": "Internal serialization error"})
.to_string()
})
}
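// Entry point for /ws: when auth is enabled, validates the ?token= JWT (audience
// "authenticated", `sub` claim as user ID). On failure the socket is upgraded and then
// immediately closed with a 4xxx close code; otherwise the connection is handed to
// handle_socket.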
pub async fn ws_handler(
ws: WebSocketUpgrade,
query: Query<HashMap<String, String>>,
_headers: HeaderMap,
Extension(state): Extension<AppState>,
) -> impl IntoResponse {
let token = query.0.get("token").cloned();
let mut user_id = state.config.default_user_id.clone();
let mut auth_error_code: Option<u16> = None;
if state.config.enable_auth {
match token {
Some(token_str) => {
debug!("🔐 Authenticating WebSocket connection");
let mut validation = Validation::new(state.config.jwt_algorithm);
validation.set_audience(&["authenticated"]);
let key = DecodingKey::from_secret(state.config.jwt_secret.as_bytes());
match decode::<Claims>(&token_str, &key, &validation) {
Ok(token_data) => {
user_id = token_data.claims.sub.clone();
debug!("✅ WebSocket authenticated for user: {}", user_id);
}
Err(e) => {
warn!("⚠️ JWT validation failed: {}", e);
auth_error_code = Some(4003);
}
}
}
None => {
warn!("⚠️ Missing authentication token in WebSocket connection");
auth_error_code = Some(4001);
}
}
} else {
debug!("🔓 WebSocket connection without auth (auth disabled)");
}
if let Some(code) = auth_error_code {
error!("❌ WebSocket authentication failed with code: {}", code);
state
.mgr
.stats
.connections_failed_auth
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
state
.mgr
.stats
.connections_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
return ws
.on_upgrade(move |mut socket: WebSocket| async move {
let close_frame = Some(CloseFrame {
code,
reason: "Authentication failed".into(),
});
let _ = socket.send(Message::Close(close_frame)).await;
let _ = socket.close().await;
})
.into_response();
}
debug!("✅ WebSocket connection established for user: {}", user_id);
ws.on_upgrade(move |socket| {
handle_socket(
socket,
user_id,
state.mgr.clone(),
state.config.max_message_size_limit,
)
})
}
async fn update_subscription_stats(mgr: &ConnectionManager, channel: &str, add: bool) {
if add {
mgr.stats
.subscriptions_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
mgr.stats
.subscriptions_active
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let mut channel_stats = mgr.stats.channels_active.write().await;
let count = channel_stats.entry(channel.to_string()).or_insert(0);
*count += 1;
} else {
mgr.stats
.unsubscriptions_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
mgr.stats
.subscriptions_active
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
let mut channel_stats = mgr.stats.channels_active.write().await;
if let Some(count) = channel_stats.get_mut(channel) {
*count = count.saturating_sub(1);
if *count == 0 {
channel_stats.remove(channel);
}
}
}
}
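// Per-connection loop: registers the client with the ConnectionManager, then multiplexes
// between outbound messages queued by the broadcaster (rx) and inbound client messages
// (socket.recv). On disconnect, all of the client's subscriptions and stats are cleaned up.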
pub async fn handle_socket(
mut socket: WebSocket,
user_id: String,
mgr: std::sync::Arc<ConnectionManager>,
max_size: usize,
) {
let client_id = mgr
.next_id
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let (tx, mut rx) = mpsc::channel::<String>(10);
info!("👋 New WebSocket client {} for user: {}", client_id, user_id);
// Update connection stats
mgr.stats
.connections_total
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
mgr.stats
.connections_active
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
// Update active users
{
let mut active_users = mgr.stats.active_users.write().await;
let count = active_users.entry(user_id.clone()).or_insert(0);
*count += 1;
}
{
let mut clients = mgr.clients.write().await;
clients.insert(client_id, (user_id.clone(), tx));
}
{
let mut client_channels = mgr.client_channels.write().await;
client_channels.insert(client_id, std::collections::HashSet::new());
}
loop {
tokio::select! {
msg = rx.recv() => {
if let Some(msg) = msg {
if socket.send(Message::Text(msg)).await.is_err() {
break;
}
} else {
break;
}
}
incoming = socket.recv() => {
let msg = match incoming {
Some(Ok(msg)) => msg,
_ => break,
};
match msg {
Message::Text(text) => {
if text.len() > max_size {
warn!("⚠️ Message from client {} exceeds size limit: {} > {}", client_id, text.len(), max_size);
let err_resp = serialize_message(&WSMessage {
method: "error".to_string(),
success: Some(false),
error: Some("Message exceeds size limit".to_string()),
..Default::default()
});
if socket.send(Message::Text(err_resp)).await.is_err() {
break;
}
continue;
}
mgr.stats.messages_received_total.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let ws_msg: WSMessage = match serde_json::from_str(&text) {
Ok(m) => m,
Err(e) => {
warn!("⚠️ Invalid message format from client {}: {}", client_id, e);
mgr.stats.errors_json_parse.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
mgr.stats.errors_total.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let err_resp = serialize_message(&WSMessage {
method: "error".to_string(),
success: Some(false),
error: Some("Invalid message format. Review the schema and retry".to_string()),
..Default::default()
});
if socket.send(Message::Text(err_resp)).await.is_err() {
break;
}
continue;
}
};
debug!("📥 Received {} message from client {}", ws_msg.method, client_id);
match ws_msg.method.as_str() {
"subscribe_graph_execution" => {
let graph_exec_id = match &ws_msg.data {
Some(Value::Object(map)) => map.get("graph_exec_id").and_then(|v| v.as_str()),
_ => None,
};
let Some(graph_exec_id) = graph_exec_id else {
warn!("⚠️ Missing graph_exec_id in subscribe_graph_execution from client {}", client_id);
let err_resp = json!({"method": "error", "success": false, "error": "Missing graph_exec_id"});
if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
break;
}
continue;
};
let channel = format!("{user_id}|graph_exec#{graph_exec_id}");
debug!("📌 Client {} subscribing to channel: {}", client_id, channel);
{
let mut subs = mgr.subscribers.write().await;
subs.entry(channel.clone()).or_insert(std::collections::HashSet::new()).insert(client_id);
}
{
let mut chs = mgr.client_channels.write().await;
if let Some(set) = chs.get_mut(&client_id) {
set.insert(channel.clone());
}
}
// Update subscription stats
update_subscription_stats(&mgr, &channel, true).await;
let resp = WSMessage {
method: "subscribe_graph_execution".to_string(),
success: Some(true),
channel: Some(channel),
..Default::default()
};
if socket.send(Message::Text(serialize_message(&resp))).await.is_err() {
break;
}
}
"subscribe_graph_executions" => {
let graph_id = match &ws_msg.data {
Some(Value::Object(map)) => map.get("graph_id").and_then(|v| v.as_str()),
_ => None,
};
let Some(graph_id) = graph_id else {
let err_resp = json!({"method": "error", "success": false, "error": "Missing graph_id"});
if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
break;
}
continue;
};
let channel = format!("{user_id}|graph#{graph_id}|executions");
{
let mut subs = mgr.subscribers.write().await;
subs.entry(channel.clone()).or_insert(std::collections::HashSet::new()).insert(client_id);
}
{
let mut chs = mgr.client_channels.write().await;
if let Some(set) = chs.get_mut(&client_id) {
set.insert(channel.clone());
}
}
debug!("📌 Client {} subscribing to channel: {}", client_id, channel);
// Update subscription stats
update_subscription_stats(&mgr, &channel, true).await;
let resp = WSMessage {
method: "subscribe_graph_executions".to_string(),
success: Some(true),
channel: Some(channel),
..Default::default()
};
if socket.send(Message::Text(serialize_message(&resp))).await.is_err() {
break;
}
}
"unsubscribe" => {
let channel = match &ws_msg.data {
Some(Value::String(s)) => Some(s.as_str()),
Some(Value::Object(map)) => map.get("channel").and_then(|v| v.as_str()),
_ => None,
};
let Some(channel) = channel else {
let err_resp = json!({"method": "error", "success": false, "error": "Missing channel"});
if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
break;
}
continue;
};
let channel = channel.to_string();
if !channel.starts_with(&format!("{user_id}|")) {
let err_resp = json!({"method": "error", "success": false, "error": "Unauthorized channel"});
if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
break;
}
continue;
}
{
let mut subs = mgr.subscribers.write().await;
if let Some(set) = subs.get_mut(&channel) {
set.remove(&client_id);
if set.is_empty() {
subs.remove(&channel);
}
}
}
{
let mut chs = mgr.client_channels.write().await;
if let Some(set) = chs.get_mut(&client_id) {
set.remove(&channel);
}
}
// Update subscription stats
update_subscription_stats(&mgr, &channel, false).await;
let resp = WSMessage {
method: "unsubscribe".to_string(),
success: Some(true),
channel: Some(channel),
..Default::default()
};
if socket.send(Message::Text(serialize_message(&resp))).await.is_err() {
break;
}
}
"heartbeat" => {
if ws_msg.data == Some(Value::String("ping".to_string())) {
let resp = WSMessage {
method: "heartbeat".to_string(),
data: Some(Value::String("pong".to_string())),
success: Some(true),
..Default::default()
};
if socket.send(Message::Text(serialize_message(&resp))).await.is_err() {
break;
}
} else {
let err_resp = json!({"method": "error", "success": false, "error": "Invalid heartbeat"});
if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
break;
}
}
}
_ => {
warn!("❓ Unknown method '{}' from client {}", ws_msg.method, client_id);
let err_resp = json!({"method": "error", "success": false, "error": "Unknown method"});
if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
break;
}
}
}
}
Message::Close(_) => break,
Message::Ping(_) => {
if socket.send(Message::Pong(vec![])).await.is_err() {
break;
}
}
Message::Pong(_) => {}
_ => {}
}
}
else => break,
}
}
// Cleanup
debug!("👋 WebSocket client {} disconnected, cleaning up", client_id);
// Update connection stats
mgr.stats
.connections_active
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
// Update active users
{
let mut active_users = mgr.stats.active_users.write().await;
if let Some(count) = active_users.get_mut(&user_id) {
*count = count.saturating_sub(1);
if *count == 0 {
active_users.remove(&user_id);
}
}
}
let channels = {
let mut client_channels = mgr.client_channels.write().await;
client_channels.remove(&client_id).unwrap_or_default()
};
{
let mut subs = mgr.subscribers.write().await;
for channel in &channels {
if let Some(set) = subs.get_mut(channel) {
set.remove(&client_id);
if set.is_empty() {
subs.remove(channel);
}
}
}
}
// Update subscription stats for all channels the client was subscribed to
for channel in &channels {
update_subscription_stats(&mgr, channel, false).await;
}
{
let mut clients = mgr.clients.write().await;
clients.remove(&client_id);
}
debug!("✨ Cleanup completed for client {}", client_id);
}

View File

@@ -0,0 +1,26 @@
#![deny(warnings)]
#![deny(clippy::unwrap_used)]
#![deny(clippy::panic)]
#![deny(clippy::unimplemented)]
#![deny(clippy::todo)]
pub mod config;
pub mod connection_manager;
pub mod handlers;
pub mod models;
pub mod stats;
pub use config::Config;
pub use connection_manager::ConnectionManager;
pub use handlers::ws_handler;
pub use stats::Stats;
use std::sync::Arc;
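/// Shared application state injected into every handler via axum's `Extension` layer.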
#[derive(Clone)]
pub struct AppState {
pub mgr: Arc<ConnectionManager>,
pub config: Arc<Config>,
pub stats: Arc<Stats>,
}

View File

@@ -0,0 +1,172 @@
use axum::{
body::Body,
http::{header, StatusCode},
response::Response,
routing::get,
Router,
};
use clap::Parser;
use std::sync::Arc;
use tokio::net::TcpListener;
use tower_http::cors::{Any, CorsLayer};
use tracing::{debug, error, info};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use crate::config::Config;
use crate::connection_manager::ConnectionManager;
use crate::handlers::ws_handler;
async fn stats_handler(
axum::Extension(state): axum::Extension<AppState>,
) -> Result<axum::response::Json<stats::StatsSnapshot>, StatusCode> {
let snapshot = state.stats.snapshot().await;
Ok(axum::response::Json(snapshot))
}
async fn prometheus_handler(
axum::Extension(state): axum::Extension<AppState>,
) -> Result<Response, StatusCode> {
let snapshot = state.stats.snapshot().await;
let prometheus_text = state.stats.to_prometheus_format(&snapshot);
Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "text/plain; version=0.0.4")
.body(Body::from(prometheus_text))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
mod config;
mod connection_manager;
mod handlers;
mod models;
mod stats;
#[derive(Parser, Debug)]
#[command(author, version, about)]
struct Cli {
/// Path to a TOML configuration file
#[arg(short = 'c', long = "config", value_name = "FILE")]
config: Option<std::path::PathBuf>,
}
#[derive(Clone)]
pub struct AppState {
mgr: Arc<ConnectionManager>,
config: Arc<Config>,
stats: Arc<stats::Stats>,
}
#[tokio::main]
async fn main() {
// Initialize tracing
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "websocket=info,tower_http=debug".into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
info!("🚀 Starting WebSocket API server");
let cli = Cli::parse();
let config = Arc::new(Config::load(cli.config.as_deref()));
info!(
"⚙️ Configuration loaded - host: {}, port: {}, auth: {}",
config.host, config.port, config.enable_auth
);
let redis_client = match redis::Client::open(config.redis_url.clone()) {
Ok(client) => {
debug!("✅ Redis client created successfully");
client
}
Err(e) => {
error!(
"❌ Failed to create Redis client: {}. Please check REDIS_URL environment variable",
e
);
std::process::exit(1);
}
};
let stats = Arc::new(stats::Stats::default());
let mgr = Arc::new(ConnectionManager::new(
redis_client,
config.execution_event_bus_name.clone(),
stats.clone(),
));
let mgr_clone = mgr.clone();
tokio::spawn(async move {
debug!("📡 Starting event broadcaster task");
mgr_clone.run_broadcaster().await;
});
let state = AppState {
mgr,
config: config.clone(),
stats,
};
let app = Router::new()
.route("/ws", get(ws_handler))
.route("/stats", get(stats_handler))
.route("/metrics", get(prometheus_handler))
.layer(axum::Extension(state));
let cors = if config.backend_cors_allow_origins.is_empty() {
// If no specific origins configured, allow any origin but without credentials
CorsLayer::new()
.allow_methods(Any)
.allow_headers(Any)
.allow_origin(Any)
} else {
// If specific origins configured, allow credentials
CorsLayer::new()
.allow_methods([
axum::http::Method::GET,
axum::http::Method::POST,
axum::http::Method::PUT,
axum::http::Method::DELETE,
axum::http::Method::OPTIONS,
])
.allow_headers(vec![
axum::http::header::CONTENT_TYPE,
axum::http::header::AUTHORIZATION,
])
.allow_credentials(true)
.allow_origin(
config
.backend_cors_allow_origins
.iter()
.filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
.collect::<Vec<_>>(),
)
};
let app = app.layer(cors);
let addr = format!("{}:{}", config.host, config.port);
let listener = match TcpListener::bind(&addr).await {
Ok(listener) => {
info!("🎧 WebSocket server listening on: {}", addr);
listener
}
Err(e) => {
error!(
"❌ Failed to bind to {}: {}. Please check if the port is already in use",
addr, e
);
std::process::exit(1);
}
};
info!("✨ WebSocket API server ready to accept connections");
if let Err(e) = axum::serve(listener, app.into_make_service()).await {
error!("💥 Server error: {}", e);
std::process::exit(1);
}
}

View File

@@ -0,0 +1,103 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct WSMessage {
pub method: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub data: Option<Value>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub success: Option<bool>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub channel: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
}
#[derive(Deserialize)]
pub struct Claims {
pub sub: String,
}
// Event models moved from events.rs
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "event_type")]
pub enum ExecutionEvent {
#[serde(rename = "graph_execution_update")]
GraphExecutionUpdate(GraphExecutionEvent),
#[serde(rename = "node_execution_update")]
NodeExecutionUpdate(NodeExecutionEvent),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphExecutionEvent {
pub id: String,
pub graph_id: String,
pub graph_version: u32,
pub user_id: String,
pub status: ExecutionStatus,
pub started_at: Option<String>,
pub ended_at: Option<String>,
pub preset_id: Option<String>,
pub stats: Option<ExecutionStats>,
// Keep these as JSON since they vary by graph
pub inputs: Value,
pub outputs: Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeExecutionEvent {
pub node_exec_id: String,
pub node_id: String,
pub graph_exec_id: String,
pub graph_id: String,
pub graph_version: u32,
pub user_id: String,
pub block_id: String,
pub status: ExecutionStatus,
pub add_time: String,
pub queue_time: Option<String>,
pub start_time: Option<String>,
pub end_time: Option<String>,
// Keep these as JSON since they vary by node type
pub input_data: Value,
pub output_data: Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionStats {
pub cost: f64,
pub duration: f64,
pub duration_cpu_only: f64,
pub error: Option<String>,
pub node_error_count: u32,
pub node_exec_count: u32,
pub node_exec_time: f64,
pub node_exec_time_cpu_only: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ExecutionStatus {
Queued,
Running,
Completed,
Failed,
Incomplete,
Terminated,
}
// Wrapper for the Redis event that includes the payload
#[derive(Debug, Deserialize)]
pub struct RedisEventWrapper {
pub payload: ExecutionEvent,
}
impl RedisEventWrapper {
pub fn parse(json_str: &str) -> Result<Self, serde_json::Error> {
serde_json::from_str(json_str)
}
}

View File

@@ -0,0 +1,238 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::RwLock;
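/// Runtime metrics: lock-free atomic counters plus two RwLock-guarded maps for per-channel
/// and per-user counts. Snapshots are served as JSON at /stats and in Prometheus text
/// format at /metrics.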
#[derive(Default)]
pub struct Stats {
// Connection metrics
pub connections_total: AtomicU64,
pub connections_active: AtomicU64,
pub connections_failed_auth: AtomicU64,
// Message metrics
pub messages_received_total: AtomicU64,
pub messages_sent_total: AtomicU64,
pub messages_failed_total: AtomicU64,
// Subscription metrics
pub subscriptions_total: AtomicU64,
pub subscriptions_active: AtomicU64,
pub unsubscriptions_total: AtomicU64,
// Event metrics by type
pub events_received_total: AtomicU64,
pub graph_execution_events: AtomicU64,
pub node_execution_events: AtomicU64,
// Redis metrics
pub redis_messages_received: AtomicU64,
pub redis_messages_ignored: AtomicU64,
// Channel metrics
pub channels_active: RwLock<HashMap<String, usize>>, // channel -> subscriber count
// User metrics
pub active_users: RwLock<HashMap<String, usize>>, // user_id -> connection count
// Error metrics
pub errors_total: AtomicU64,
pub errors_json_parse: AtomicU64,
pub errors_message_size: AtomicU64,
}
#[derive(Serialize, Deserialize)]
pub struct StatsSnapshot {
// Connection metrics
pub connections_total: u64,
pub connections_active: u64,
pub connections_failed_auth: u64,
// Message metrics
pub messages_received_total: u64,
pub messages_sent_total: u64,
pub messages_failed_total: u64,
// Subscription metrics
pub subscriptions_total: u64,
pub subscriptions_active: u64,
pub unsubscriptions_total: u64,
// Event metrics
pub events_received_total: u64,
pub graph_execution_events: u64,
pub node_execution_events: u64,
// Redis metrics
pub redis_messages_received: u64,
pub redis_messages_ignored: u64,
// Channel metrics
pub channels_active_count: usize,
pub total_subscribers: usize,
// User metrics
pub active_users_count: usize,
// Error metrics
pub errors_total: u64,
pub errors_json_parse: u64,
pub errors_message_size: u64,
}

impl Stats {
    pub async fn snapshot(&self) -> StatsSnapshot {
        // Take read locks for HashMap data - it's ok if this is slightly stale
        let channels = self.channels_active.read().await;
        let total_subscribers: usize = channels.values().sum();
        let channels_active_count = channels.len();
        drop(channels); // Release lock early

        let users = self.active_users.read().await;
        let active_users_count = users.len();
        drop(users); // Release lock early

        StatsSnapshot {
            connections_total: self.connections_total.load(Ordering::Relaxed),
            connections_active: self.connections_active.load(Ordering::Relaxed),
            connections_failed_auth: self.connections_failed_auth.load(Ordering::Relaxed),
            messages_received_total: self.messages_received_total.load(Ordering::Relaxed),
            messages_sent_total: self.messages_sent_total.load(Ordering::Relaxed),
            messages_failed_total: self.messages_failed_total.load(Ordering::Relaxed),
            subscriptions_total: self.subscriptions_total.load(Ordering::Relaxed),
            subscriptions_active: self.subscriptions_active.load(Ordering::Relaxed),
            unsubscriptions_total: self.unsubscriptions_total.load(Ordering::Relaxed),
            events_received_total: self.events_received_total.load(Ordering::Relaxed),
            graph_execution_events: self.graph_execution_events.load(Ordering::Relaxed),
            node_execution_events: self.node_execution_events.load(Ordering::Relaxed),
            redis_messages_received: self.redis_messages_received.load(Ordering::Relaxed),
            redis_messages_ignored: self.redis_messages_ignored.load(Ordering::Relaxed),
            channels_active_count,
            total_subscribers,
            active_users_count,
            errors_total: self.errors_total.load(Ordering::Relaxed),
            errors_json_parse: self.errors_json_parse.load(Ordering::Relaxed),
            errors_message_size: self.errors_message_size.load(Ordering::Relaxed),
        }
    }

    pub fn to_prometheus_format(&self, snapshot: &StatsSnapshot) -> String {
        let mut output = String::new();
        // Connection metrics
        output.push_str("# HELP ws_connections_total Total number of WebSocket connections\n");
        output.push_str("# TYPE ws_connections_total counter\n");
        output.push_str(&format!(
            "ws_connections_total {}\n\n",
            snapshot.connections_total
        ));
        output.push_str(
            "# HELP ws_connections_active Current number of active WebSocket connections\n",
        );
        output.push_str("# TYPE ws_connections_active gauge\n");
        output.push_str(&format!(
            "ws_connections_active {}\n\n",
            snapshot.connections_active
        ));
        output
            .push_str("# HELP ws_connections_failed_auth Total number of failed authentications\n");
        output.push_str("# TYPE ws_connections_failed_auth counter\n");
        output.push_str(&format!(
            "ws_connections_failed_auth {}\n\n",
            snapshot.connections_failed_auth
        ));
        // Message metrics
        output.push_str(
            "# HELP ws_messages_received_total Total number of messages received from clients\n",
        );
        output.push_str("# TYPE ws_messages_received_total counter\n");
        output.push_str(&format!(
            "ws_messages_received_total {}\n\n",
            snapshot.messages_received_total
        ));
        output.push_str("# HELP ws_messages_sent_total Total number of messages sent to clients\n");
        output.push_str("# TYPE ws_messages_sent_total counter\n");
        output.push_str(&format!(
            "ws_messages_sent_total {}\n\n",
            snapshot.messages_sent_total
        ));
        // Subscription metrics
        output.push_str("# HELP ws_subscriptions_active Current number of active subscriptions\n");
        output.push_str("# TYPE ws_subscriptions_active gauge\n");
        output.push_str(&format!(
            "ws_subscriptions_active {}\n\n",
            snapshot.subscriptions_active
        ));
        // Event metrics
        output.push_str(
            "# HELP ws_events_received_total Total number of events received from Redis\n",
        );
        output.push_str("# TYPE ws_events_received_total counter\n");
        output.push_str(&format!(
            "ws_events_received_total {}\n\n",
            snapshot.events_received_total
        ));
        output.push_str(
            "# HELP ws_graph_execution_events_total Total number of graph execution events\n",
        );
        output.push_str("# TYPE ws_graph_execution_events_total counter\n");
        output.push_str(&format!(
            "ws_graph_execution_events_total {}\n\n",
            snapshot.graph_execution_events
        ));
        output.push_str(
            "# HELP ws_node_execution_events_total Total number of node execution events\n",
        );
        output.push_str("# TYPE ws_node_execution_events_total counter\n");
        output.push_str(&format!(
            "ws_node_execution_events_total {}\n\n",
            snapshot.node_execution_events
        ));
        // Channel metrics
        output.push_str("# HELP ws_channels_active Number of active channels\n");
        output.push_str("# TYPE ws_channels_active gauge\n");
        output.push_str(&format!(
            "ws_channels_active {}\n\n",
            snapshot.channels_active_count
        ));
        output.push_str(
            "# HELP ws_total_subscribers Total number of subscribers across all channels\n",
        );
        output.push_str("# TYPE ws_total_subscribers gauge\n");
        output.push_str(&format!(
            "ws_total_subscribers {}\n\n",
            snapshot.total_subscribers
        ));
        // User metrics
        output.push_str("# HELP ws_active_users Number of unique users with active connections\n");
        output.push_str("# TYPE ws_active_users gauge\n");
        output.push_str(&format!(
            "ws_active_users {}\n\n",
            snapshot.active_users_count
        ));
        // Error metrics
        output.push_str("# HELP ws_errors_total Total number of errors\n");
        output.push_str("# TYPE ws_errors_total counter\n");
        output.push_str(&format!("ws_errors_total {}\n", snapshot.errors_total));
        output
    }
}
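
As a usage sketch only: assuming a shared `Arc<Stats>` is threaded through the server, recording activity and rendering a `/metrics` response might look like the following. The handler name, channel key, and wiring are hypothetical and not part of the diff above:

```rust
use std::sync::Arc;
use std::sync::atomic::Ordering;

// Illustrative sketch of how the server might update Stats and serve Prometheus text.
async fn example_usage(stats: Arc<Stats>) -> String {
    // Count a new connection; Relaxed ordering matches how snapshot() reads the counters.
    stats.connections_total.fetch_add(1, Ordering::Relaxed);
    stats.connections_active.fetch_add(1, Ordering::Relaxed);

    // Track a subscriber on a (hypothetical) channel name.
    stats
        .channels_active
        .write()
        .await
        .entry("graph_executions".to_string())
        .and_modify(|n| *n += 1)
        .or_insert(1);

    // Render the Prometheus exposition text, e.g. for a /metrics endpoint.
    let snapshot = stats.snapshot().await;
    stats.to_prometheus_format(&snapshot)
}
```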