mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-11 16:18:07 -05:00
Compare commits
4 Commits
master
...
autogpt-rs
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
894e3600fb | ||
|
|
9de4b09f20 | ||
|
|
62e41d409a | ||
|
|
9f03e3af47 |
802
autogpt_platform/autogpt-rs/DATABASE_MANAGER.md
Normal file
802
autogpt_platform/autogpt-rs/DATABASE_MANAGER.md
Normal file
@@ -0,0 +1,802 @@
|
||||
# DatabaseManager Technical Specification
|
||||
|
||||
## Executive Summary
|
||||
|
||||
This document provides a complete technical specification for implementing a drop-in replacement for the AutoGPT Platform's DatabaseManager service. The replacement must maintain 100% API compatibility while preserving all functional behaviors, security requirements, and performance characteristics.
|
||||
|
||||
## 1. System Overview
|
||||
|
||||
### 1.1 Purpose
|
||||
The DatabaseManager is a centralized service that provides database access for the AutoGPT Platform's executor system. It encapsulates all database operations behind a service interface, enabling distributed execution while maintaining data consistency and security.
|
||||
|
||||
### 1.2 Architecture Pattern
|
||||
- **Service Type**: HTTP-based microservice using FastAPI
|
||||
- **Communication**: RPC-style over HTTP with JSON serialization
|
||||
- **Base Class**: Inherits from `AppService` (backend.util.service)
|
||||
- **Client Classes**: `DatabaseManagerClient` (sync) and `DatabaseManagerAsyncClient` (async)
|
||||
- **Port**: Configurable via `config.database_api_port`
|
||||
|
||||
### 1.3 Critical Requirements
|
||||
1. **API Compatibility**: All 40+ exposed methods must maintain exact signatures
|
||||
2. **Type Safety**: Full type preservation across service boundaries
|
||||
3. **User Isolation**: All operations must respect user_id boundaries
|
||||
4. **Transaction Support**: Maintain ACID properties for critical operations
|
||||
5. **Event Publishing**: Maintain Redis event bus integration for real-time updates
|
||||
|
||||
## 2. Service Implementation Requirements
|
||||
|
||||
### 2.1 Base Service Class
|
||||
|
||||
```python
|
||||
from backend.util.service import AppService, expose
|
||||
from backend.util.settings import Config
|
||||
from backend.data import db
|
||||
import logging
|
||||
|
||||
class DatabaseManager(AppService):
|
||||
"""
|
||||
REQUIRED: Inherit from AppService to get:
|
||||
- Automatic endpoint generation via @expose decorator
|
||||
- Built-in health checks at /health
|
||||
- Request/response serialization
|
||||
- Error handling and logging
|
||||
"""
|
||||
|
||||
def run_service(self) -> None:
|
||||
"""REQUIRED: Initialize database connection before starting service"""
|
||||
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
|
||||
self.run_and_wait(db.connect()) # CRITICAL: Must connect to database
|
||||
super().run_service() # Start HTTP server
|
||||
|
||||
def cleanup(self):
|
||||
"""REQUIRED: Clean disconnect on shutdown"""
|
||||
super().cleanup()
|
||||
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
|
||||
self.run_and_wait(db.disconnect()) # CRITICAL: Must disconnect cleanly
|
||||
|
||||
@classmethod
|
||||
def get_port(cls) -> int:
|
||||
"""REQUIRED: Return configured port"""
|
||||
return config.database_api_port
|
||||
```
|
||||
|
||||
### 2.2 Method Exposure Pattern
|
||||
|
||||
```python
|
||||
@staticmethod
|
||||
def _(f: Callable[P, R], name: str | None = None) -> Callable[Concatenate[object, P], R]:
|
||||
"""
|
||||
REQUIRED: Helper to expose methods with proper signatures
|
||||
- Preserves function name for endpoint generation
|
||||
- Maintains type information
|
||||
- Adds 'self' parameter for instance binding
|
||||
"""
|
||||
if name is not None:
|
||||
f.__name__ = name
|
||||
return cast(Callable[Concatenate[object, P], R], expose(f))
|
||||
```
|
||||
|
||||
### 2.3 Database Connection Management
|
||||
|
||||
**REQUIRED: Use Prisma ORM with these exact configurations:**
|
||||
|
||||
```python
|
||||
from prisma import Prisma
|
||||
|
||||
prisma = Prisma(
|
||||
auto_register=True,
|
||||
http={"timeout": HTTP_TIMEOUT}, # Default: 120 seconds
|
||||
datasource={"url": DATABASE_URL}
|
||||
)
|
||||
|
||||
# Connection lifecycle
|
||||
async def connect():
|
||||
await prisma.connect()
|
||||
|
||||
async def disconnect():
|
||||
await prisma.disconnect()
|
||||
```
|
||||
|
||||
### 2.4 Transaction Support
|
||||
|
||||
**REQUIRED: Implement both regular and locked transactions:**
|
||||
|
||||
```python
|
||||
async def transaction(timeout: float | None = None):
|
||||
"""Regular database transaction"""
|
||||
async with prisma.tx(timeout=timeout) as tx:
|
||||
yield tx
|
||||
|
||||
async def locked_transaction(key: str, timeout: float | None = None):
|
||||
"""Transaction with PostgreSQL advisory lock"""
|
||||
lock_key = zlib.crc32(key.encode("utf-8"))
|
||||
async with transaction(timeout=timeout) as tx:
|
||||
await tx.execute_raw("SELECT pg_advisory_xact_lock($1)", lock_key)
|
||||
yield tx
|
||||
```
|
||||
|
||||
## 3. Complete API Specification
|
||||
|
||||
### 3.1 Execution Management APIs
|
||||
|
||||
#### get_graph_execution
|
||||
```python
|
||||
async def get_graph_execution(
|
||||
user_id: str,
|
||||
execution_id: str,
|
||||
*,
|
||||
include_node_executions: bool = False
|
||||
) -> GraphExecution | GraphExecutionWithNodes | None
|
||||
```
|
||||
**Behavior**:
|
||||
- Returns execution only if user_id matches
|
||||
- Optionally includes all node executions
|
||||
- Returns None if not found or unauthorized
|
||||
|
||||
#### get_graph_executions
|
||||
```python
|
||||
async def get_graph_executions(
|
||||
user_id: str,
|
||||
graph_id: str | None = None,
|
||||
*,
|
||||
limit: int = 50,
|
||||
graph_version: int | None = None,
|
||||
cursor: str | None = None,
|
||||
preset_id: str | None = None
|
||||
) -> tuple[list[GraphExecution], str | None]
|
||||
```
|
||||
**Behavior**:
|
||||
- Paginated results with cursor
|
||||
- Filter by graph_id, version, or preset_id
|
||||
- Returns (executions, next_cursor)
|
||||
|
||||
#### create_graph_execution
|
||||
```python
|
||||
async def create_graph_execution(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
starting_nodes_input: dict[str, dict[str, Any]],
|
||||
user_id: str,
|
||||
preset_id: str | None = None
|
||||
) -> GraphExecutionWithNodes
|
||||
```
|
||||
**Behavior**:
|
||||
- Creates execution with status "QUEUED"
|
||||
- Initializes all nodes with "PENDING" status
|
||||
- Publishes creation event to Redis
|
||||
- Uses locked transaction on graph_id
|
||||
|
||||
#### update_graph_execution_start_time
|
||||
```python
|
||||
async def update_graph_execution_start_time(
|
||||
graph_exec_id: str
|
||||
) -> None
|
||||
```
|
||||
**Behavior**:
|
||||
- Sets start_time to current timestamp
|
||||
- Only updates if currently NULL
|
||||
|
||||
#### update_graph_execution_stats
|
||||
```python
|
||||
async def update_graph_execution_stats(
|
||||
graph_exec_id: str,
|
||||
status: AgentExecutionStatus | None = None,
|
||||
stats: dict[str, Any] | None = None
|
||||
) -> GraphExecution | None
|
||||
```
|
||||
**Behavior**:
|
||||
- Updates status and/or stats atomically
|
||||
- Sets end_time if status is terminal (COMPLETED/FAILED)
|
||||
- Publishes update event to Redis
|
||||
- Returns updated execution
|
||||
|
||||
#### get_node_execution
|
||||
```python
|
||||
async def get_node_execution(
|
||||
node_exec_id: str
|
||||
) -> NodeExecutionResult | None
|
||||
```
|
||||
**Behavior**:
|
||||
- No user_id check (relies on graph execution security)
|
||||
- Includes all input/output data
|
||||
|
||||
#### get_node_executions
|
||||
```python
|
||||
async def get_node_executions(
|
||||
graph_exec_id: str
|
||||
) -> list[NodeExecutionResult]
|
||||
```
|
||||
**Behavior**:
|
||||
- Returns all node executions for graph
|
||||
- Ordered by creation time
|
||||
|
||||
#### get_latest_node_execution
|
||||
```python
|
||||
async def get_latest_node_execution(
|
||||
graph_exec_id: str,
|
||||
node_id: str
|
||||
) -> NodeExecutionResult | None
|
||||
```
|
||||
**Behavior**:
|
||||
- Returns most recent execution of specific node
|
||||
- Used for retry/rerun scenarios
|
||||
|
||||
#### update_node_execution_status
|
||||
```python
|
||||
async def update_node_execution_status(
|
||||
node_exec_id: str,
|
||||
status: AgentExecutionStatus,
|
||||
execution_data: dict[str, Any] | None = None,
|
||||
stats: dict[str, Any] | None = None
|
||||
) -> NodeExecutionResult
|
||||
```
|
||||
**Behavior**:
|
||||
- Updates status atomically
|
||||
- Sets end_time for terminal states
|
||||
- Optionally updates stats/data
|
||||
- Publishes event to Redis
|
||||
- Returns updated execution
|
||||
|
||||
#### update_node_execution_status_batch
|
||||
```python
|
||||
async def update_node_execution_status_batch(
|
||||
execution_updates: list[NodeExecutionUpdate]
|
||||
) -> list[NodeExecutionResult]
|
||||
```
|
||||
**Behavior**:
|
||||
- Batch update multiple nodes in single transaction
|
||||
- Each update can have different status/stats
|
||||
- Publishes events for all updates
|
||||
- Returns all updated executions
|
||||
|
||||
#### update_node_execution_stats
|
||||
```python
|
||||
async def update_node_execution_stats(
|
||||
node_exec_id: str,
|
||||
stats: dict[str, Any]
|
||||
) -> NodeExecutionResult
|
||||
```
|
||||
**Behavior**:
|
||||
- Updates only stats field
|
||||
- Merges with existing stats
|
||||
- Does not affect status
|
||||
|
||||
#### upsert_execution_input
|
||||
```python
|
||||
async def upsert_execution_input(
|
||||
node_id: str,
|
||||
graph_exec_id: str,
|
||||
input_name: str,
|
||||
input_data: Any,
|
||||
node_exec_id: str | None = None
|
||||
) -> tuple[str, BlockInput]
|
||||
```
|
||||
**Behavior**:
|
||||
- Creates or updates input data
|
||||
- If node_exec_id not provided, creates node execution
|
||||
- Serializes input_data to JSON
|
||||
- Returns (node_exec_id, input_object)
|
||||
|
||||
#### upsert_execution_output
|
||||
```python
|
||||
async def upsert_execution_output(
|
||||
node_exec_id: str,
|
||||
output_name: str,
|
||||
output_data: Any
|
||||
) -> None
|
||||
```
|
||||
**Behavior**:
|
||||
- Creates or updates output data
|
||||
- Serializes output_data to JSON
|
||||
- No return value
|
||||
|
||||
#### get_execution_kv_data
|
||||
```python
|
||||
async def get_execution_kv_data(
|
||||
user_id: str,
|
||||
key: str
|
||||
) -> Any | None
|
||||
```
|
||||
**Behavior**:
|
||||
- User-scoped key-value storage
|
||||
- Returns deserialized JSON data
|
||||
- Returns None if key not found
|
||||
|
||||
#### set_execution_kv_data
|
||||
```python
|
||||
async def set_execution_kv_data(
|
||||
user_id: str,
|
||||
node_exec_id: str,
|
||||
key: str,
|
||||
data: Any
|
||||
) -> Any | None
|
||||
```
|
||||
**Behavior**:
|
||||
- Sets user-scoped key-value data
|
||||
- Associates with node execution
|
||||
- Serializes data to JSON
|
||||
- Returns previous value or None
|
||||
|
||||
#### get_block_error_stats
|
||||
```python
|
||||
async def get_block_error_stats() -> list[BlockErrorStats]
|
||||
```
|
||||
**Behavior**:
|
||||
- Aggregates error counts by block_id
|
||||
- Last 7 days of data
|
||||
- Groups by error type
|
||||
|
||||
### 3.2 Graph Management APIs
|
||||
|
||||
#### get_node
|
||||
```python
|
||||
async def get_node(
|
||||
node_id: str
|
||||
) -> AgentNode | None
|
||||
```
|
||||
**Behavior**:
|
||||
- Returns node with block data
|
||||
- No user_id check (public blocks)
|
||||
|
||||
#### get_graph
|
||||
```python
|
||||
async def get_graph(
|
||||
graph_id: str,
|
||||
version: int | None = None,
|
||||
user_id: str | None = None,
|
||||
for_export: bool = False,
|
||||
include_subgraphs: bool = False
|
||||
) -> GraphModel | None
|
||||
```
|
||||
**Behavior**:
|
||||
- Returns latest version if version=None
|
||||
- Checks user_id for private graphs
|
||||
- for_export=True excludes internal fields
|
||||
- include_subgraphs=True loads nested graphs
|
||||
|
||||
#### get_connected_output_nodes
|
||||
```python
|
||||
async def get_connected_output_nodes(
|
||||
node_id: str,
|
||||
output_name: str
|
||||
) -> list[tuple[AgentNode, AgentNodeLink]]
|
||||
```
|
||||
**Behavior**:
|
||||
- Returns downstream nodes connected to output
|
||||
- Includes link metadata
|
||||
- Used for execution flow
|
||||
|
||||
#### get_graph_metadata
|
||||
```python
|
||||
async def get_graph_metadata(
|
||||
graph_id: str,
|
||||
user_id: str
|
||||
) -> GraphMetadata | None
|
||||
```
|
||||
**Behavior**:
|
||||
- Returns graph metadata without full definition
|
||||
- User must own or have access to graph
|
||||
|
||||
### 3.3 Credit System APIs
|
||||
|
||||
#### get_credits
|
||||
```python
|
||||
async def get_credits(
|
||||
user_id: str
|
||||
) -> int
|
||||
```
|
||||
**Behavior**:
|
||||
- Returns current credit balance
|
||||
- Always non-negative
|
||||
|
||||
#### spend_credits
|
||||
```python
|
||||
async def spend_credits(
|
||||
user_id: str,
|
||||
cost: int,
|
||||
metadata: UsageTransactionMetadata
|
||||
) -> int
|
||||
```
|
||||
**Behavior**:
|
||||
- Deducts credits atomically
|
||||
- Creates transaction record
|
||||
- Throws InsufficientCredits if balance too low
|
||||
- Returns new balance
|
||||
- metadata includes: block_id, node_exec_id, context
|
||||
|
||||
### 3.4 User Management APIs
|
||||
|
||||
#### get_user_metadata
|
||||
```python
|
||||
async def get_user_metadata(
|
||||
user_id: str
|
||||
) -> UserMetadata
|
||||
```
|
||||
**Behavior**:
|
||||
- Returns user preferences and settings
|
||||
- Creates default if not exists
|
||||
|
||||
#### update_user_metadata
|
||||
```python
|
||||
async def update_user_metadata(
|
||||
user_id: str,
|
||||
data: UserMetadataDTO
|
||||
) -> UserMetadata
|
||||
```
|
||||
**Behavior**:
|
||||
- Partial update of metadata
|
||||
- Validates against schema
|
||||
- Returns updated metadata
|
||||
|
||||
#### get_user_integrations
|
||||
```python
|
||||
async def get_user_integrations(
|
||||
user_id: str
|
||||
) -> UserIntegrations
|
||||
```
|
||||
**Behavior**:
|
||||
- Returns OAuth credentials
|
||||
- Decrypts sensitive data
|
||||
- Creates empty if not exists
|
||||
|
||||
#### update_user_integrations
|
||||
```python
|
||||
async def update_user_integrations(
|
||||
user_id: str,
|
||||
data: UserIntegrations
|
||||
) -> None
|
||||
```
|
||||
**Behavior**:
|
||||
- Updates integration credentials
|
||||
- Encrypts sensitive data
|
||||
- No return value
|
||||
|
||||
### 3.5 User Communication APIs
|
||||
|
||||
#### get_active_user_ids_in_timerange
|
||||
```python
|
||||
async def get_active_user_ids_in_timerange(
|
||||
start_time: datetime,
|
||||
end_time: datetime
|
||||
) -> list[str]
|
||||
```
|
||||
**Behavior**:
|
||||
- Returns users with graph executions in range
|
||||
- Used for analytics/notifications
|
||||
|
||||
#### get_user_email_by_id
|
||||
```python
|
||||
async def get_user_email_by_id(
|
||||
user_id: str
|
||||
) -> str | None
|
||||
```
|
||||
**Behavior**:
|
||||
- Returns user's email address
|
||||
- None if user not found
|
||||
|
||||
#### get_user_email_verification
|
||||
```python
|
||||
async def get_user_email_verification(
|
||||
user_id: str
|
||||
) -> UserEmailVerification
|
||||
```
|
||||
**Behavior**:
|
||||
- Returns email and verification status
|
||||
- Used for notification filtering
|
||||
|
||||
#### get_user_notification_preference
|
||||
```python
|
||||
async def get_user_notification_preference(
|
||||
user_id: str
|
||||
) -> NotificationPreference
|
||||
```
|
||||
**Behavior**:
|
||||
- Returns notification settings
|
||||
- Creates default if not exists
|
||||
|
||||
### 3.6 Notification APIs
|
||||
|
||||
#### create_or_add_to_user_notification_batch
|
||||
```python
|
||||
async def create_or_add_to_user_notification_batch(
|
||||
user_id: str,
|
||||
notification_type: NotificationType,
|
||||
notification_data: NotificationEvent
|
||||
) -> UserNotificationBatchDTO
|
||||
```
|
||||
**Behavior**:
|
||||
- Adds to existing batch or creates new
|
||||
- Batches by type for efficiency
|
||||
- Returns updated batch
|
||||
|
||||
#### empty_user_notification_batch
|
||||
```python
|
||||
async def empty_user_notification_batch(
|
||||
user_id: str,
|
||||
notification_type: NotificationType
|
||||
) -> None
|
||||
```
|
||||
**Behavior**:
|
||||
- Clears all notifications of type
|
||||
- Used after sending batch
|
||||
|
||||
#### get_all_batches_by_type
|
||||
```python
|
||||
async def get_all_batches_by_type(
|
||||
notification_type: NotificationType
|
||||
) -> list[UserNotificationBatchDTO]
|
||||
```
|
||||
**Behavior**:
|
||||
- Returns all user batches of type
|
||||
- Used by notification service
|
||||
|
||||
#### get_user_notification_batch
|
||||
```python
|
||||
async def get_user_notification_batch(
|
||||
user_id: str,
|
||||
notification_type: NotificationType
|
||||
) -> UserNotificationBatchDTO | None
|
||||
```
|
||||
**Behavior**:
|
||||
- Returns user's batch for type
|
||||
- None if no batch exists
|
||||
|
||||
#### get_user_notification_oldest_message_in_batch
|
||||
```python
|
||||
async def get_user_notification_oldest_message_in_batch(
|
||||
user_id: str,
|
||||
notification_type: NotificationType
|
||||
) -> NotificationEvent | None
|
||||
```
|
||||
**Behavior**:
|
||||
- Returns oldest notification in batch
|
||||
- Used for batch timing decisions
|
||||
|
||||
## 4. Client Implementation Requirements
|
||||
|
||||
### 4.1 Synchronous Client
|
||||
|
||||
```python
|
||||
class DatabaseManagerClient(AppServiceClient):
|
||||
"""
|
||||
REQUIRED: Synchronous client that:
|
||||
- Converts async methods to sync using endpoint_to_sync
|
||||
- Maintains exact method signatures
|
||||
- Handles connection pooling
|
||||
- Implements retry logic
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_service_type(cls):
|
||||
return DatabaseManager
|
||||
|
||||
# Example method mapping
|
||||
get_graph_execution = endpoint_to_sync(DatabaseManager.get_graph_execution)
|
||||
```
|
||||
|
||||
### 4.2 Asynchronous Client
|
||||
|
||||
```python
|
||||
class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
"""
|
||||
REQUIRED: Async client that:
|
||||
- Directly references async methods
|
||||
- No conversion needed
|
||||
- Shares connection pool
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_service_type(cls):
|
||||
return DatabaseManager
|
||||
|
||||
# Direct method reference
|
||||
get_graph_execution = DatabaseManager.get_graph_execution
|
||||
```
|
||||
|
||||
## 5. Data Models
|
||||
|
||||
### 5.1 Core Enums
|
||||
|
||||
```python
|
||||
class AgentExecutionStatus(str, Enum):
|
||||
PENDING = "PENDING"
|
||||
QUEUED = "QUEUED"
|
||||
RUNNING = "RUNNING"
|
||||
COMPLETED = "COMPLETED"
|
||||
FAILED = "FAILED"
|
||||
CANCELED = "CANCELED"
|
||||
|
||||
class NotificationType(str, Enum):
|
||||
SYSTEM = "SYSTEM"
|
||||
REVIEW = "REVIEW"
|
||||
EXECUTION = "EXECUTION"
|
||||
MARKETING = "MARKETING"
|
||||
```
|
||||
|
||||
### 5.2 Key Data Models
|
||||
|
||||
All models must exactly match the Prisma schema definitions. Key models include:
|
||||
|
||||
- `GraphExecution`: Execution metadata with stats
|
||||
- `GraphExecutionWithNodes`: Includes all node executions
|
||||
- `NodeExecutionResult`: Node execution with I/O data
|
||||
- `GraphModel`: Complete graph definition
|
||||
- `UserIntegrations`: OAuth credentials
|
||||
- `UsageTransactionMetadata`: Credit usage context
|
||||
- `NotificationEvent`: Individual notification data
|
||||
|
||||
## 6. Security Requirements
|
||||
|
||||
### 6.1 User Isolation
|
||||
- **CRITICAL**: All user-scoped operations MUST filter by user_id
|
||||
- Never expose data across user boundaries
|
||||
- Use database-level row security where possible
|
||||
|
||||
### 6.2 Authentication
|
||||
- Service assumes authentication handled by API gateway
|
||||
- user_id parameter is trusted after authentication
|
||||
- No additional auth checks within service
|
||||
|
||||
### 6.3 Data Protection
|
||||
- Encrypt sensitive integration credentials
|
||||
- Use HMAC for unsubscribe tokens
|
||||
- Never log sensitive data
|
||||
|
||||
## 7. Performance Requirements
|
||||
|
||||
### 7.1 Connection Management
|
||||
- Maintain persistent database connection
|
||||
- Use connection pooling (default: 10 connections)
|
||||
- Implement exponential backoff for retries
|
||||
|
||||
### 7.2 Query Optimization
|
||||
- Use indexes for all WHERE clauses
|
||||
- Batch operations where possible
|
||||
- Limit default result sets (50 items)
|
||||
|
||||
### 7.3 Event Publishing
|
||||
- Publish events asynchronously
|
||||
- Don't block on event delivery
|
||||
- Use fire-and-forget pattern
|
||||
|
||||
## 8. Error Handling
|
||||
|
||||
### 8.1 Standard Exceptions
|
||||
```python
|
||||
class InsufficientCredits(Exception):
|
||||
"""Raised when user lacks credits"""
|
||||
|
||||
class NotFoundError(Exception):
|
||||
"""Raised when entity not found"""
|
||||
|
||||
class AuthorizationError(Exception):
|
||||
"""Raised when user lacks access"""
|
||||
```
|
||||
|
||||
### 8.2 Error Response Format
|
||||
```json
|
||||
{
|
||||
"error": "error_type",
|
||||
"message": "Human readable message",
|
||||
"details": {} // Optional additional context
|
||||
}
|
||||
```
|
||||
|
||||
## 9. Testing Requirements
|
||||
|
||||
### 9.1 Unit Tests
|
||||
- Test each method in isolation
|
||||
- Mock database calls
|
||||
- Verify user_id filtering
|
||||
|
||||
### 9.2 Integration Tests
|
||||
- Test with real database
|
||||
- Verify transaction boundaries
|
||||
- Test concurrent operations
|
||||
|
||||
### 9.3 Service Tests
|
||||
- Test HTTP endpoint generation
|
||||
- Verify serialization/deserialization
|
||||
- Test error handling
|
||||
|
||||
## 10. Implementation Checklist
|
||||
|
||||
### Phase 1: Core Service Setup
|
||||
- [ ] Create DatabaseManager class inheriting from AppService
|
||||
- [ ] Implement run_service() with database connection
|
||||
- [ ] Implement cleanup() with proper disconnect
|
||||
- [ ] Configure port from settings
|
||||
- [ ] Set up method exposure helper
|
||||
|
||||
### Phase 2: Execution APIs (15 methods)
|
||||
- [ ] get_graph_execution
|
||||
- [ ] get_graph_executions
|
||||
- [ ] get_graph_execution_meta
|
||||
- [ ] create_graph_execution
|
||||
- [ ] update_graph_execution_start_time
|
||||
- [ ] update_graph_execution_stats
|
||||
- [ ] get_node_execution
|
||||
- [ ] get_node_executions
|
||||
- [ ] get_latest_node_execution
|
||||
- [ ] update_node_execution_status
|
||||
- [ ] update_node_execution_status_batch
|
||||
- [ ] update_node_execution_stats
|
||||
- [ ] upsert_execution_input
|
||||
- [ ] upsert_execution_output
|
||||
- [ ] get_execution_kv_data
|
||||
- [ ] set_execution_kv_data
|
||||
- [ ] get_block_error_stats
|
||||
|
||||
### Phase 3: Graph APIs (4 methods)
|
||||
- [ ] get_node
|
||||
- [ ] get_graph
|
||||
- [ ] get_connected_output_nodes
|
||||
- [ ] get_graph_metadata
|
||||
|
||||
### Phase 4: Credit APIs (2 methods)
|
||||
- [ ] get_credits
|
||||
- [ ] spend_credits
|
||||
|
||||
### Phase 5: User APIs (4 methods)
|
||||
- [ ] get_user_metadata
|
||||
- [ ] update_user_metadata
|
||||
- [ ] get_user_integrations
|
||||
- [ ] update_user_integrations
|
||||
|
||||
### Phase 6: Communication APIs (4 methods)
|
||||
- [ ] get_active_user_ids_in_timerange
|
||||
- [ ] get_user_email_by_id
|
||||
- [ ] get_user_email_verification
|
||||
- [ ] get_user_notification_preference
|
||||
|
||||
### Phase 7: Notification APIs (5 methods)
|
||||
- [ ] create_or_add_to_user_notification_batch
|
||||
- [ ] empty_user_notification_batch
|
||||
- [ ] get_all_batches_by_type
|
||||
- [ ] get_user_notification_batch
|
||||
- [ ] get_user_notification_oldest_message_in_batch
|
||||
|
||||
### Phase 8: Client Implementation
|
||||
- [ ] Create DatabaseManagerClient with sync methods
|
||||
- [ ] Create DatabaseManagerAsyncClient with async methods
|
||||
- [ ] Test client method generation
|
||||
- [ ] Verify type preservation
|
||||
|
||||
### Phase 9: Integration Testing
|
||||
- [ ] Test all methods with real database
|
||||
- [ ] Verify user isolation
|
||||
- [ ] Test error scenarios
|
||||
- [ ] Performance testing
|
||||
- [ ] Event publishing verification
|
||||
|
||||
### Phase 10: Deployment Validation
|
||||
- [ ] Deploy to test environment
|
||||
- [ ] Run integration test suite
|
||||
- [ ] Verify backward compatibility
|
||||
- [ ] Performance benchmarking
|
||||
- [ ] Production deployment
|
||||
|
||||
## 11. Success Criteria
|
||||
|
||||
The implementation is successful when:
|
||||
|
||||
1. **All 40+ methods** produce identical outputs to the original
|
||||
2. **Performance** is within 10% of original implementation
|
||||
3. **All tests** pass without modification
|
||||
4. **No breaking changes** to any client code
|
||||
5. **Security boundaries** are maintained
|
||||
6. **Event publishing** works identically
|
||||
7. **Error handling** matches original behavior
|
||||
|
||||
## 12. Critical Implementation Notes
|
||||
|
||||
1. **DO NOT** modify any function signatures
|
||||
2. **DO NOT** change any return types
|
||||
3. **DO NOT** add new required parameters
|
||||
4. **DO NOT** remove any functionality
|
||||
5. **ALWAYS** maintain user_id isolation
|
||||
6. **ALWAYS** publish events for state changes
|
||||
7. **ALWAYS** use transactions for multi-step operations
|
||||
8. **ALWAYS** handle errors exactly as original
|
||||
|
||||
This specification, when implemented correctly, will produce a drop-in replacement for the DatabaseManager that maintains 100% compatibility with the existing system.
|
||||
765
autogpt_platform/autogpt-rs/NOTIFICATION_SERVICE.md
Normal file
765
autogpt_platform/autogpt-rs/NOTIFICATION_SERVICE.md
Normal file
@@ -0,0 +1,765 @@
|
||||
# Notification Service Technical Specification
|
||||
|
||||
## Overview
|
||||
|
||||
The AutoGPT Platform Notification Service is a RabbitMQ-based asynchronous notification system that handles various types of user notifications including real-time alerts, batched notifications, and scheduled summaries. The service supports email delivery via Postmark and system alerts via Discord.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Core Components
|
||||
|
||||
1. **NotificationManager Service** (`notifications.py`)
|
||||
- AppService implementation with RabbitMQ integration
|
||||
- Processes notification queues asynchronously
|
||||
- Manages batching strategies and delivery timing
|
||||
- Handles email templating and sending
|
||||
|
||||
2. **RabbitMQ Message Broker**
|
||||
- Multiple queues for different notification strategies
|
||||
- Dead letter exchange for failed messages
|
||||
- Topic-based routing for message distribution
|
||||
|
||||
3. **Email Sender** (`email.py`)
|
||||
- Postmark integration for email delivery
|
||||
- Jinja2 template rendering
|
||||
- HTML email composition with unsubscribe headers
|
||||
|
||||
4. **Database Storage**
|
||||
- Notification batching tables
|
||||
- User preference storage
|
||||
- Email verification tracking
|
||||
|
||||
## Service Exposure Mechanism
|
||||
|
||||
### AppService Framework
|
||||
|
||||
The NotificationManager extends `AppService` which automatically exposes methods decorated with `@expose` as HTTP endpoints:
|
||||
|
||||
```python
|
||||
class NotificationManager(AppService):
|
||||
@expose
|
||||
def queue_weekly_summary(self):
|
||||
# Implementation
|
||||
|
||||
@expose
|
||||
def process_existing_batches(self, notification_types: list[NotificationType]):
|
||||
# Implementation
|
||||
|
||||
@expose
|
||||
async def discord_system_alert(self, content: str):
|
||||
# Implementation
|
||||
```
|
||||
|
||||
### Automatic HTTP Endpoint Creation
|
||||
|
||||
When the service starts, the AppService base class:
|
||||
1. Scans for methods with `@expose` decorator
|
||||
2. Creates FastAPI routes for each exposed method:
|
||||
- Route path: `/{method_name}`
|
||||
- HTTP method: POST
|
||||
- Endpoint handler: Generated via `_create_fastapi_endpoint()`
|
||||
|
||||
### Service Client Access
|
||||
|
||||
#### NotificationManagerClient
|
||||
```python
|
||||
class NotificationManagerClient(AppServiceClient):
|
||||
@classmethod
|
||||
def get_service_type(cls):
|
||||
return NotificationManager
|
||||
|
||||
# Direct method references (sync)
|
||||
process_existing_batches = NotificationManager.process_existing_batches
|
||||
queue_weekly_summary = NotificationManager.queue_weekly_summary
|
||||
|
||||
# Async-to-sync conversion
|
||||
discord_system_alert = endpoint_to_sync(NotificationManager.discord_system_alert)
|
||||
```
|
||||
|
||||
#### Client Usage Pattern
|
||||
```python
|
||||
# Get client instance
|
||||
client = get_service_client(NotificationManagerClient)
|
||||
|
||||
# Call exposed methods via HTTP
|
||||
client.process_existing_batches([NotificationType.AGENT_RUN])
|
||||
client.queue_weekly_summary()
|
||||
client.discord_system_alert("System alert message")
|
||||
```
|
||||
|
||||
### HTTP Communication Details
|
||||
|
||||
1. **Service URL**: `http://{host}:{notification_service_port}`
|
||||
- Default port: 8007
|
||||
- Host: Configurable via settings
|
||||
|
||||
2. **Request Format**:
|
||||
- Method: POST
|
||||
- Path: `/{method_name}`
|
||||
- Body: JSON with method parameters
|
||||
|
||||
3. **Client Implementation**:
|
||||
- Uses `httpx` for HTTP requests
|
||||
- Automatic retry on connection failures
|
||||
- Configurable timeout (default from api_call_timeout)
|
||||
|
||||
### Direct Function Calls
|
||||
|
||||
The service also exposes two functions that can be called directly without going through the service client:
|
||||
|
||||
```python
|
||||
# Sync version - used by ExecutionManager
|
||||
def queue_notification(event: NotificationEventModel) -> NotificationResult
|
||||
|
||||
# Async version - used by credit system
|
||||
async def queue_notification_async(event: NotificationEventModel) -> NotificationResult
|
||||
```
|
||||
|
||||
These functions:
|
||||
- Connect directly to RabbitMQ
|
||||
- Publish messages to appropriate queues
|
||||
- Return success/failure status
|
||||
- Are NOT exposed via HTTP
|
||||
|
||||
## Message Queuing Architecture
|
||||
|
||||
### RabbitMQ Configuration
|
||||
|
||||
#### Exchanges
|
||||
```python
|
||||
NOTIFICATION_EXCHANGE = Exchange(name="notifications", type=ExchangeType.TOPIC)
|
||||
DEAD_LETTER_EXCHANGE = Exchange(name="dead_letter", type=ExchangeType.TOPIC)
|
||||
```
|
||||
|
||||
#### Queues
|
||||
1. **immediate_notifications**
|
||||
- Routing Key: `notification.immediate.#`
|
||||
- Dead Letter: `failed.immediate`
|
||||
- For: Critical alerts, errors
|
||||
|
||||
2. **admin_notifications**
|
||||
- Routing Key: `notification.admin.#`
|
||||
- Dead Letter: `failed.admin`
|
||||
- For: Refund requests, system alerts
|
||||
|
||||
3. **summary_notifications**
|
||||
- Routing Key: `notification.summary.#`
|
||||
- Dead Letter: `failed.summary`
|
||||
- For: Daily/weekly summaries
|
||||
|
||||
4. **batch_notifications**
|
||||
- Routing Key: `notification.batch.#`
|
||||
- Dead Letter: `failed.batch`
|
||||
- For: Agent runs, batched events
|
||||
|
||||
5. **failed_notifications**
|
||||
- Routing Key: `failed.#`
|
||||
- For: All failed messages
|
||||
|
||||
### Queue Strategies (QueueType enum)
|
||||
|
||||
1. **IMMEDIATE**: Send right away (errors, critical notifications)
|
||||
2. **BATCH**: Batch for configured delay (agent runs)
|
||||
3. **SUMMARY**: Scheduled digest (daily/weekly summaries)
|
||||
4. **BACKOFF**: Exponential backoff strategy (defined but not fully implemented)
|
||||
5. **ADMIN**: Admin-only notifications
|
||||
|
||||
## Notification Types
|
||||
|
||||
### Enum Values (NotificationType)
|
||||
```python
|
||||
AGENT_RUN # Batch strategy, 1 day delay
|
||||
ZERO_BALANCE # Backoff strategy, 60 min delay
|
||||
LOW_BALANCE # Immediate strategy
|
||||
BLOCK_EXECUTION_FAILED # Backoff strategy, 60 min delay
|
||||
CONTINUOUS_AGENT_ERROR # Backoff strategy, 60 min delay
|
||||
DAILY_SUMMARY # Summary strategy
|
||||
WEEKLY_SUMMARY # Summary strategy
|
||||
MONTHLY_SUMMARY # Summary strategy
|
||||
REFUND_REQUEST # Admin strategy
|
||||
REFUND_PROCESSED # Admin strategy
|
||||
```
|
||||
|
||||
## Integration Points
|
||||
|
||||
### 1. Scheduler Integration
|
||||
The scheduler service (`backend.executor.scheduler`) imports monitoring functions that call the NotificationManagerClient:
|
||||
|
||||
```python
|
||||
from backend.monitoring import (
|
||||
process_existing_batches,
|
||||
process_weekly_summary,
|
||||
)
|
||||
|
||||
# These are scheduled as cron jobs
|
||||
```
|
||||
|
||||
### 2. Execution Manager Integration
|
||||
The ExecutionManager directly calls `queue_notification()` for:
|
||||
- Agent run completions
|
||||
- Low balance alerts
|
||||
|
||||
```python
|
||||
from backend.notifications.notifications import queue_notification
|
||||
|
||||
# Called after graph execution completes
|
||||
queue_notification(NotificationEventModel(
|
||||
user_id=graph_exec.user_id,
|
||||
type=NotificationType.AGENT_RUN,
|
||||
data=AgentRunData(...)
|
||||
))
|
||||
```
|
||||
|
||||
### 3. Credit System Integration
|
||||
The credit system uses `queue_notification_async()` for:
|
||||
- Refund requests
|
||||
- Refund processed notifications
|
||||
|
||||
```python
|
||||
from backend.notifications.notifications import queue_notification_async
|
||||
|
||||
await queue_notification_async(NotificationEventModel(
|
||||
user_id=user_id,
|
||||
type=NotificationType.REFUND_REQUEST,
|
||||
data=RefundRequestData(...)
|
||||
))
|
||||
```
|
||||
|
||||
### 4. Monitoring Module Wrappers
|
||||
The monitoring module provides wrapper functions that are used by the scheduler:
|
||||
|
||||
```python
|
||||
# backend/monitoring/notification_monitor.py
|
||||
def process_existing_batches(**kwargs):
|
||||
args = NotificationJobArgs(**kwargs)
|
||||
get_notification_manager_client().process_existing_batches(
|
||||
args.notification_types
|
||||
)
|
||||
|
||||
def process_weekly_summary(**kwargs):
|
||||
get_notification_manager_client().queue_weekly_summary()
|
||||
```
|
||||
|
||||
## Data Models
|
||||
|
||||
### Base Event Model
|
||||
```typescript
|
||||
interface BaseEventModel {
|
||||
type: NotificationType;
|
||||
user_id: string;
|
||||
created_at: string; // ISO datetime with timezone
|
||||
}
|
||||
```
|
||||
|
||||
### Notification Event Model
|
||||
```typescript
|
||||
interface NotificationEventModel<T> extends BaseEventModel {
|
||||
data: T;
|
||||
}
|
||||
```
|
||||
|
||||
### Notification Data Types
|
||||
|
||||
#### AgentRunData
|
||||
```typescript
|
||||
interface AgentRunData {
|
||||
agent_name: string;
|
||||
credits_used: number;
|
||||
execution_time: number;
|
||||
node_count: number;
|
||||
graph_id: string;
|
||||
outputs: Array<Record<string, any>>;
|
||||
}
|
||||
```
|
||||
|
||||
#### ZeroBalanceData
|
||||
```typescript
|
||||
interface ZeroBalanceData {
|
||||
last_transaction: number;
|
||||
last_transaction_time: string; // ISO datetime with timezone
|
||||
top_up_link: string;
|
||||
}
|
||||
```
|
||||
|
||||
#### LowBalanceData
|
||||
```typescript
|
||||
interface LowBalanceData {
|
||||
agent_name: string;
|
||||
current_balance: number; // credits (100 = $1)
|
||||
billing_page_link: string;
|
||||
shortfall: number;
|
||||
}
|
||||
```
|
||||
|
||||
#### BlockExecutionFailedData
|
||||
```typescript
|
||||
interface BlockExecutionFailedData {
|
||||
block_name: string;
|
||||
block_id: string;
|
||||
error_message: string;
|
||||
graph_id: string;
|
||||
node_id: string;
|
||||
execution_id: string;
|
||||
}
|
||||
```
|
||||
|
||||
#### ContinuousAgentErrorData
|
||||
```typescript
|
||||
interface ContinuousAgentErrorData {
|
||||
agent_name: string;
|
||||
error_message: string;
|
||||
graph_id: string;
|
||||
execution_id: string;
|
||||
start_time: string; // ISO datetime with timezone
|
||||
error_time: string; // ISO datetime with timezone
|
||||
attempts: number;
|
||||
}
|
||||
```
|
||||
|
||||
#### Summary Data Types
|
||||
```typescript
|
||||
interface BaseSummaryData {
|
||||
total_credits_used: number;
|
||||
total_executions: number;
|
||||
most_used_agent: string;
|
||||
total_execution_time: number;
|
||||
successful_runs: number;
|
||||
failed_runs: number;
|
||||
average_execution_time: number;
|
||||
cost_breakdown: Record<string, number>;
|
||||
}
|
||||
|
||||
interface DailySummaryData extends BaseSummaryData {
|
||||
date: string; // ISO datetime with timezone
|
||||
}
|
||||
|
||||
interface WeeklySummaryData extends BaseSummaryData {
|
||||
start_date: string; // ISO datetime with timezone
|
||||
end_date: string; // ISO datetime with timezone
|
||||
}
|
||||
```
|
||||
|
||||
#### RefundRequestData
|
||||
```typescript
|
||||
interface RefundRequestData {
|
||||
user_id: string;
|
||||
user_name: string;
|
||||
user_email: string;
|
||||
transaction_id: string;
|
||||
refund_request_id: string;
|
||||
reason: string;
|
||||
amount: number;
|
||||
balance: number;
|
||||
}
|
||||
```
|
||||
|
||||
### Summary Parameters
|
||||
```typescript
|
||||
interface BaseSummaryParams {
|
||||
start_date: string; // ISO datetime with timezone
|
||||
end_date: string; // ISO datetime with timezone
|
||||
}
|
||||
|
||||
interface DailySummaryParams extends BaseSummaryParams {
|
||||
date: string; // ISO datetime with timezone
|
||||
}
|
||||
|
||||
interface WeeklySummaryParams extends BaseSummaryParams {
|
||||
start_date: string; // ISO datetime with timezone
|
||||
end_date: string; // ISO datetime with timezone
|
||||
}
|
||||
```
|
||||
|
||||
## Database Schema
|
||||
|
||||
### NotificationEvent Table
|
||||
```sql
|
||||
model NotificationEvent {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
UserNotificationBatch UserNotificationBatch? @relation
|
||||
userNotificationBatchId String?
|
||||
type NotificationType
|
||||
data Json
|
||||
@@index([userNotificationBatchId])
|
||||
}
|
||||
```
|
||||
|
||||
### UserNotificationBatch Table
|
||||
```sql
|
||||
model UserNotificationBatch {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
userId String
|
||||
User User @relation
|
||||
type NotificationType
|
||||
Notifications NotificationEvent[]
|
||||
@@unique([userId, type])
|
||||
}
|
||||
```
|
||||
|
||||
## API Methods
|
||||
|
||||
### Exposed Service Methods (via HTTP)
|
||||
|
||||
#### queue_weekly_summary()
|
||||
- **HTTP Endpoint**: `POST /queue_weekly_summary`
|
||||
- **Purpose**: Triggers weekly summary generation for all active users
|
||||
- **Process**:
|
||||
1. Runs in background executor
|
||||
2. Queries users active in last 7 days
|
||||
3. Queues summary notification for each user
|
||||
- **Used by**: Scheduler service (via cron)
|
||||
|
||||
#### process_existing_batches(notification_types: list[NotificationType])
|
||||
- **HTTP Endpoint**: `POST /process_existing_batches`
|
||||
- **Purpose**: Processes aged-out batches for specified notification types
|
||||
- **Process**:
|
||||
1. Runs in background executor
|
||||
2. Retrieves all batches for given types
|
||||
3. Checks if oldest message exceeds max delay
|
||||
4. Sends batched email if aged out
|
||||
5. Clears processed batches
|
||||
- **Used by**: Scheduler service (via cron)
|
||||
|
||||
#### discord_system_alert(content: str)
|
||||
- **HTTP Endpoint**: `POST /discord_system_alert`
|
||||
- **Purpose**: Sends system alerts to Discord channel
|
||||
- **Async**: Yes (converted to sync by client)
|
||||
- **Used by**: Monitoring services
|
||||
|
||||
### Direct Queue Functions (not via HTTP)
|
||||
|
||||
#### queue_notification(event: NotificationEventModel) -> NotificationResult
|
||||
- **Purpose**: Queue a notification (sync version)
|
||||
- **Used by**: ExecutionManager (same process)
|
||||
- **Direct RabbitMQ**: Yes
|
||||
|
||||
#### queue_notification_async(event: NotificationEventModel) -> NotificationResult
|
||||
- **Purpose**: Queue a notification (async version)
|
||||
- **Used by**: Credit system (async context)
|
||||
- **Direct RabbitMQ**: Yes
|
||||
|
||||
## Message Processing Flow
|
||||
|
||||
### 1. Message Routing
|
||||
```python
|
||||
def get_routing_key(event_type: NotificationType) -> str:
|
||||
strategy = NotificationTypeOverride(event_type).strategy
|
||||
if strategy == QueueType.IMMEDIATE:
|
||||
return f"notification.immediate.{event_type.value}"
|
||||
elif strategy == QueueType.BATCH:
|
||||
return f"notification.batch.{event_type.value}"
|
||||
# ... etc
|
||||
```
|
||||
|
||||
### 2. Queue Processing Methods
|
||||
|
||||
#### _process_immediate(message: str) -> bool
|
||||
1. Parse message to NotificationEventModel
|
||||
2. Retrieve user email
|
||||
3. Check user preferences and email verification
|
||||
4. Send email immediately via EmailSender
|
||||
5. Return True if successful
|
||||
|
||||
#### _process_batch(message: str) -> bool
|
||||
1. Parse message to NotificationEventModel
|
||||
2. Add to user's notification batch
|
||||
3. Check if batch is old enough (based on delay)
|
||||
4. If aged out:
|
||||
- Retrieve all batch messages
|
||||
- Send combined email
|
||||
- Clear batch
|
||||
5. Return True if processed or batched
|
||||
|
||||
#### _process_summary(message: str) -> bool
|
||||
1. Parse message to SummaryParamsEventModel
|
||||
2. Gather summary data (credits, executions, etc.)
|
||||
- **Note**: Currently returns hardcoded placeholder data
|
||||
3. Format and send summary email
|
||||
4. Return True if successful
|
||||
|
||||
#### _process_admin_message(message: str) -> bool
|
||||
1. Parse message
|
||||
2. Send to configured admin email
|
||||
3. No user preference checks
|
||||
4. Return True if successful
|
||||
|
||||
## Email Delivery
|
||||
|
||||
### EmailSender Class
|
||||
|
||||
#### Template Loading
|
||||
- Base template: `templates/base.html.jinja2`
|
||||
- Notification templates: `templates/{notification_type}.html.jinja2`
|
||||
- Subject templates from NotificationTypeOverride
|
||||
- **Note**: Templates use `.html.jinja2` extension, not just `.html`
|
||||
|
||||
#### Email Composition
|
||||
```python
|
||||
def send_templated(
|
||||
notification: NotificationType,
|
||||
user_email: str,
|
||||
data: NotificationEventModel | list[NotificationEventModel],
|
||||
user_unsub_link: str | None = None
|
||||
)
|
||||
```
|
||||
|
||||
#### Postmark Integration
|
||||
- API Token: `settings.secrets.postmark_server_api_token`
|
||||
- Sender Email: `settings.config.postmark_sender_email`
|
||||
- Headers:
|
||||
- `List-Unsubscribe-Post: List-Unsubscribe=One-Click`
|
||||
- `List-Unsubscribe: <{unsubscribe_link}>`
|
||||
|
||||
## User Preferences and Permissions
|
||||
|
||||
### Email Verification Check
|
||||
```python
|
||||
validated_email = get_db().get_user_email_verification(user_id)
|
||||
```
|
||||
|
||||
### Notification Preferences
|
||||
```python
|
||||
preferences = get_db().get_user_notification_preference(user_id).preferences
|
||||
# Returns dict[NotificationType, bool]
|
||||
```
|
||||
|
||||
### Preference Fields in User Model
|
||||
- `notifyOnAgentRun`
|
||||
- `notifyOnZeroBalance`
|
||||
- `notifyOnLowBalance`
|
||||
- `notifyOnBlockExecutionFailed`
|
||||
- `notifyOnContinuousAgentError`
|
||||
- `notifyOnDailySummary`
|
||||
- `notifyOnWeeklySummary`
|
||||
- `notifyOnMonthlySummary`
|
||||
|
||||
### Unsubscribe Link Generation
|
||||
```python
|
||||
def generate_unsubscribe_link(user_id: str) -> str:
|
||||
# HMAC-SHA256 signed token
|
||||
# Format: base64(user_id:signature_hex)
|
||||
# URL: {platform_base_url}/api/email/unsubscribe?token={token}
|
||||
```
|
||||
|
||||
## Batching Logic
|
||||
|
||||
### Batch Delays (get_batch_delay)
|
||||
|
||||
**Note**: The delay configuration exists for multiple notification types, but only notifications with `QueueType.BATCH` strategy actually use batching. Others use different strategies:
|
||||
|
||||
- `AGENT_RUN`: 1 day (Strategy: BATCH - actually uses batching)
|
||||
- `ZERO_BALANCE`: 60 minutes configured (Strategy: BACKOFF - not batched)
|
||||
- `LOW_BALANCE`: 60 minutes configured (Strategy: IMMEDIATE - sent immediately)
|
||||
- `BLOCK_EXECUTION_FAILED`: 60 minutes configured (Strategy: BACKOFF - not batched)
|
||||
- `CONTINUOUS_AGENT_ERROR`: 60 minutes configured (Strategy: BACKOFF - not batched)
|
||||
|
||||
### Batch Processing
|
||||
1. Messages added to UserNotificationBatch
|
||||
2. Oldest message timestamp tracked
|
||||
3. When `oldest_timestamp + delay < now()`:
|
||||
- Batch is processed
|
||||
- All messages sent in single email
|
||||
- Batch cleared
|
||||
|
||||
## Service Lifecycle
|
||||
|
||||
### Startup
|
||||
1. Initialize FastAPI app with exposed endpoints
|
||||
2. Start HTTP server on port 8007
|
||||
3. Initialize RabbitMQ connection
|
||||
4. Create/verify exchanges and queues
|
||||
5. Set up queue consumers
|
||||
6. Start processing loop
|
||||
|
||||
### Main Loop
|
||||
```python
|
||||
while self.running:
|
||||
await self._run_queue(immediate_queue, self._process_immediate, ...)
|
||||
await self._run_queue(admin_queue, self._process_admin_message, ...)
|
||||
await self._run_queue(batch_queue, self._process_batch, ...)
|
||||
await self._run_queue(summary_queue, self._process_summary, ...)
|
||||
await asyncio.sleep(0.1)
|
||||
```
|
||||
|
||||
### Shutdown
|
||||
1. Set `running = False`
|
||||
2. Disconnect RabbitMQ
|
||||
3. Cleanup resources
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
```python
|
||||
# Service Configuration
|
||||
notification_service_port: int = 8007
|
||||
|
||||
# Email Configuration
|
||||
postmark_sender_email: str = "invalid@invalid.com"
|
||||
refund_notification_email: str = "refund@agpt.co"
|
||||
|
||||
# Security
|
||||
unsubscribe_secret_key: str = ""
|
||||
|
||||
# Secrets
|
||||
postmark_server_api_token: str = ""
|
||||
postmark_webhook_token: str = ""
|
||||
discord_bot_token: str = ""
|
||||
|
||||
# Platform URLs
|
||||
platform_base_url: str
|
||||
frontend_base_url: str
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Message Processing Errors
|
||||
- Failed messages sent to dead letter queue
|
||||
- Validation errors logged but don't crash service
|
||||
- Connection errors trigger retry with `@continuous_retry()`
|
||||
|
||||
### RabbitMQ ACK/NACK Protocol
|
||||
- Success: `message.ack()`
|
||||
- Failure: `message.reject(requeue=False)`
|
||||
- Timeout/Queue empty: Continue loop
|
||||
|
||||
### HTTP Endpoint Errors
|
||||
- Wrapped in RemoteCallError for client
|
||||
- Automatic retry available via client configuration
|
||||
- Connection failures tracked and logged
|
||||
|
||||
## System Integrations
|
||||
|
||||
### DatabaseManagerClient
|
||||
- User email retrieval
|
||||
- Email verification status
|
||||
- Notification preferences
|
||||
- Batch management
|
||||
- Active user queries
|
||||
|
||||
### Discord Integration
|
||||
- Uses SendDiscordMessageBlock
|
||||
- Configured via discord_bot_token
|
||||
- For system alerts only
|
||||
|
||||
## Implementation Checklist
|
||||
|
||||
1. **Core Service**
|
||||
- [ ] AppService implementation with @expose decorators
|
||||
- [ ] FastAPI endpoint generation
|
||||
- [ ] RabbitMQ connection management
|
||||
- [ ] Queue consumer setup
|
||||
- [ ] Message routing logic
|
||||
|
||||
2. **Service Client**
|
||||
- [ ] NotificationManagerClient implementation
|
||||
- [ ] HTTP client configuration
|
||||
- [ ] Method mapping to service endpoints
|
||||
- [ ] Async-to-sync conversions
|
||||
|
||||
3. **Message Processing**
|
||||
- [ ] Parse and validate all notification types
|
||||
- [ ] Implement all queue strategies
|
||||
- [ ] Batch management with delays
|
||||
- [ ] Summary data gathering
|
||||
|
||||
4. **Email Delivery**
|
||||
- [ ] Postmark integration
|
||||
- [ ] Template loading and rendering
|
||||
- [ ] Unsubscribe header support
|
||||
- [ ] HTML email composition
|
||||
|
||||
5. **User Management**
|
||||
- [ ] Preference checking
|
||||
- [ ] Email verification
|
||||
- [ ] Unsubscribe link generation
|
||||
- [ ] Daily limit tracking
|
||||
|
||||
6. **Batching System**
|
||||
- [ ] Database batch operations
|
||||
- [ ] Age-out checking
|
||||
- [ ] Batch clearing after send
|
||||
- [ ] Oldest message tracking
|
||||
|
||||
7. **Error Handling**
|
||||
- [ ] Dead letter queue routing
|
||||
- [ ] Message rejection on failure
|
||||
- [ ] Continuous retry wrapper
|
||||
- [ ] Validation error logging
|
||||
|
||||
8. **Scheduled Operations**
|
||||
- [ ] Weekly summary generation
|
||||
- [ ] Batch processing triggers
|
||||
- [ ] Background executor usage
|
||||
|
||||
## Security Considerations
|
||||
|
||||
1. **Service-to-Service Communication**:
|
||||
- HTTP endpoints only accessible internally
|
||||
- No authentication on service endpoints (internal network only)
|
||||
- Service discovery via host/port configuration
|
||||
|
||||
2. **User Security**:
|
||||
- Email verification required for all user notifications
|
||||
- Unsubscribe tokens HMAC-signed
|
||||
- User preferences enforced
|
||||
|
||||
3. **Admin Notifications**:
|
||||
- Separate queue, no user preference checks
|
||||
- Fixed admin email configuration
|
||||
|
||||
## Testing Considerations
|
||||
|
||||
1. **Unit Tests**
|
||||
- Message parsing and validation
|
||||
- Routing key generation
|
||||
- Batch delay calculations
|
||||
- Template rendering
|
||||
|
||||
2. **Integration Tests**
|
||||
- HTTP endpoint accessibility
|
||||
- Service client method calls
|
||||
- RabbitMQ message flow
|
||||
- Database batch operations
|
||||
- Email sending (mock Postmark)
|
||||
|
||||
3. **Load Tests**
|
||||
- High volume message processing
|
||||
- Concurrent HTTP requests
|
||||
- Batch accumulation limits
|
||||
- Memory usage under load
|
||||
|
||||
## Implementation Status Notes
|
||||
|
||||
1. **Backoff Strategy**: While `QueueType.BACKOFF` is defined and used by several notification types (ZERO_BALANCE, BLOCK_EXECUTION_FAILED, CONTINUOUS_AGENT_ERROR), the actual exponential backoff processing logic is not implemented. These messages are routed to immediate queue.
|
||||
|
||||
2. **Summary Data**: The `_gather_summary_data()` method currently returns hardcoded placeholder values rather than querying actual execution data from the database.
|
||||
|
||||
3. **Batch Processing**: Only `AGENT_RUN` notifications actually use batch processing. Other notification types with configured delays use different strategies (IMMEDIATE or BACKOFF).
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
1. **Additional Channels**
|
||||
- SMS notifications (not implemented)
|
||||
- Webhook notifications (not implemented)
|
||||
- In-app notifications
|
||||
|
||||
2. **Advanced Batching**
|
||||
- Dynamic batch sizes
|
||||
- Priority-based processing
|
||||
- Custom delay configurations
|
||||
|
||||
3. **Analytics**
|
||||
- Delivery tracking
|
||||
- Open/click rates
|
||||
- Notification effectiveness metrics
|
||||
|
||||
4. **Service Improvements**
|
||||
- Authentication for HTTP endpoints
|
||||
- Rate limiting per user
|
||||
- Circuit breaker patterns
|
||||
- Implement actual backoff processing for BACKOFF strategy
|
||||
- Implement real summary data gathering
|
||||
474
autogpt_platform/autogpt-rs/SCHEDULER.md
Normal file
474
autogpt_platform/autogpt-rs/SCHEDULER.md
Normal file
@@ -0,0 +1,474 @@
|
||||
# AutoGPT Platform Scheduler Technical Specification
|
||||
|
||||
## Executive Summary
|
||||
|
||||
This document provides a comprehensive technical specification for the AutoGPT Platform Scheduler service. The scheduler is responsible for managing scheduled graph executions, system monitoring tasks, and periodic maintenance operations. This specification is designed to enable a complete reimplementation that maintains 100% compatibility with the existing system.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [System Architecture](#system-architecture)
|
||||
2. [Service Implementation](#service-implementation)
|
||||
3. [Data Models](#data-models)
|
||||
4. [API Endpoints](#api-endpoints)
|
||||
5. [Database Schema](#database-schema)
|
||||
6. [External Dependencies](#external-dependencies)
|
||||
7. [Authentication & Authorization](#authentication--authorization)
|
||||
8. [Process Management](#process-management)
|
||||
9. [Error Handling](#error-handling)
|
||||
10. [Configuration](#configuration)
|
||||
11. [Testing Strategy](#testing-strategy)
|
||||
|
||||
## System Architecture
|
||||
|
||||
### Overview
|
||||
|
||||
The scheduler operates as an independent microservice within the AutoGPT platform, implementing the `AppService` base class pattern. It runs on a dedicated port (default: 8003) and exposes HTTP/JSON-RPC endpoints for communication with other services.
|
||||
|
||||
### Core Components
|
||||
|
||||
1. **Scheduler Service** (`backend/executor/scheduler.py:156`)
|
||||
- Extends `AppService` base class
|
||||
- Manages APScheduler instance with multiple jobstores
|
||||
- Handles lifecycle management and graceful shutdown
|
||||
|
||||
2. **Scheduler Client** (`backend/executor/scheduler.py:354`)
|
||||
- Extends `AppServiceClient` base class
|
||||
- Provides async/sync method wrappers for RPC calls
|
||||
- Implements automatic retry and connection pooling
|
||||
|
||||
3. **Entry Points**
|
||||
- Main executable: `backend/scheduler.py`
|
||||
- Service launcher: `backend/app.py`
|
||||
|
||||
## Service Implementation
|
||||
|
||||
### Base Service Pattern
|
||||
|
||||
```python
|
||||
class Scheduler(AppService):
|
||||
scheduler: BlockingScheduler
|
||||
|
||||
def __init__(self, register_system_tasks: bool = True):
|
||||
self.register_system_tasks = register_system_tasks
|
||||
|
||||
@classmethod
|
||||
def get_port(cls) -> int:
|
||||
return config.execution_scheduler_port # Default: 8003
|
||||
|
||||
@classmethod
|
||||
def db_pool_size(cls) -> int:
|
||||
return config.scheduler_db_pool_size # Default: 3
|
||||
|
||||
def run_service(self):
|
||||
# Initialize scheduler with jobstores
|
||||
# Register system tasks if enabled
|
||||
# Start scheduler blocking loop
|
||||
|
||||
def cleanup(self):
|
||||
# Graceful shutdown of scheduler
|
||||
# Wait=False for immediate termination
|
||||
```
|
||||
|
||||
### Jobstore Configuration
|
||||
|
||||
The scheduler uses three distinct jobstores:
|
||||
|
||||
1. **EXECUTION** (`Jobstores.EXECUTION.value`)
|
||||
- Type: SQLAlchemyJobStore
|
||||
- Table: `apscheduler_jobs`
|
||||
- Purpose: Graph execution schedules
|
||||
- Persistence: Required
|
||||
|
||||
2. **BATCHED_NOTIFICATIONS** (`Jobstores.BATCHED_NOTIFICATIONS.value`)
|
||||
- Type: SQLAlchemyJobStore
|
||||
- Table: `apscheduler_jobs_batched_notifications`
|
||||
- Purpose: Batched notification processing
|
||||
- Persistence: Required
|
||||
|
||||
3. **WEEKLY_NOTIFICATIONS** (`Jobstores.WEEKLY_NOTIFICATIONS.value`)
|
||||
- Type: MemoryJobStore
|
||||
- Purpose: Weekly summary notifications
|
||||
- Persistence: Not required
|
||||
|
||||
### System Tasks
|
||||
|
||||
When `register_system_tasks=True`, the following monitoring tasks are registered:
|
||||
|
||||
1. **Weekly Summary Processing**
|
||||
- Job ID: `process_weekly_summary`
|
||||
- Schedule: `0 * * * *` (hourly)
|
||||
- Function: `monitoring.process_weekly_summary`
|
||||
- Jobstore: WEEKLY_NOTIFICATIONS
|
||||
|
||||
2. **Late Execution Monitoring**
|
||||
- Job ID: `report_late_executions`
|
||||
- Schedule: Interval (config.execution_late_notification_threshold_secs)
|
||||
- Function: `monitoring.report_late_executions`
|
||||
- Jobstore: EXECUTION
|
||||
|
||||
3. **Block Error Rate Monitoring**
|
||||
- Job ID: `report_block_error_rates`
|
||||
- Schedule: Interval (config.block_error_rate_check_interval_secs)
|
||||
- Function: `monitoring.report_block_error_rates`
|
||||
- Jobstore: EXECUTION
|
||||
|
||||
4. **Cloud Storage Cleanup**
|
||||
- Job ID: `cleanup_expired_files`
|
||||
- Schedule: Interval (config.cloud_storage_cleanup_interval_hours * 3600)
|
||||
- Function: `cleanup_expired_files`
|
||||
- Jobstore: EXECUTION
|
||||
|
||||
## Data Models
|
||||
|
||||
### GraphExecutionJobArgs
|
||||
|
||||
```python
|
||||
class GraphExecutionJobArgs(BaseModel):
|
||||
user_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
cron: str
|
||||
input_data: BlockInput
|
||||
input_credentials: dict[str, CredentialsMetaInput] = Field(default_factory=dict)
|
||||
```
|
||||
|
||||
### GraphExecutionJobInfo
|
||||
|
||||
```python
|
||||
class GraphExecutionJobInfo(GraphExecutionJobArgs):
|
||||
id: str
|
||||
name: str
|
||||
next_run_time: str
|
||||
|
||||
@staticmethod
|
||||
def from_db(job_args: GraphExecutionJobArgs, job_obj: JobObj) -> "GraphExecutionJobInfo":
|
||||
return GraphExecutionJobInfo(
|
||||
id=job_obj.id,
|
||||
name=job_obj.name,
|
||||
next_run_time=job_obj.next_run_time.isoformat(),
|
||||
**job_args.model_dump(),
|
||||
)
|
||||
```
|
||||
|
||||
### NotificationJobArgs
|
||||
|
||||
```python
|
||||
class NotificationJobArgs(BaseModel):
|
||||
notification_types: list[NotificationType]
|
||||
cron: str
|
||||
```
|
||||
|
||||
### CredentialsMetaInput
|
||||
|
||||
```python
|
||||
class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
id: str
|
||||
title: Optional[str] = None
|
||||
provider: CP
|
||||
type: CT
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
All endpoints are exposed via the `@expose` decorator and follow HTTP POST JSON-RPC pattern.
|
||||
|
||||
### 1. Add Graph Execution Schedule
|
||||
|
||||
**Endpoint**: `/add_graph_execution_schedule`
|
||||
|
||||
**Request Body**:
|
||||
```json
|
||||
{
|
||||
"user_id": "string",
|
||||
"graph_id": "string",
|
||||
"graph_version": "integer",
|
||||
"cron": "string (crontab format)",
|
||||
"input_data": {},
|
||||
"input_credentials": {},
|
||||
"name": "string (optional)"
|
||||
}
|
||||
```
|
||||
|
||||
**Response**: `GraphExecutionJobInfo`
|
||||
|
||||
**Behavior**:
|
||||
- Creates APScheduler job with CronTrigger
|
||||
- Uses job kwargs to store GraphExecutionJobArgs
|
||||
- Sets `replace_existing=True` to allow updates
|
||||
- Returns job info with generated ID and next run time
|
||||
|
||||
### 2. Delete Graph Execution Schedule
|
||||
|
||||
**Endpoint**: `/delete_graph_execution_schedule`
|
||||
|
||||
**Request Body**:
|
||||
```json
|
||||
{
|
||||
"schedule_id": "string",
|
||||
"user_id": "string"
|
||||
}
|
||||
```
|
||||
|
||||
**Response**: `GraphExecutionJobInfo`
|
||||
|
||||
**Behavior**:
|
||||
- Validates schedule exists in EXECUTION jobstore
|
||||
- Verifies user_id matches job's user_id
|
||||
- Removes job from scheduler
|
||||
- Returns deleted job info
|
||||
|
||||
**Errors**:
|
||||
- `NotFoundError`: If job doesn't exist
|
||||
- `NotAuthorizedError`: If user_id doesn't match
|
||||
|
||||
### 3. Get Graph Execution Schedules
|
||||
|
||||
**Endpoint**: `/get_graph_execution_schedules`
|
||||
|
||||
**Request Body**:
|
||||
```json
|
||||
{
|
||||
"graph_id": "string (optional)",
|
||||
"user_id": "string (optional)"
|
||||
}
|
||||
```
|
||||
|
||||
**Response**: `list[GraphExecutionJobInfo]`
|
||||
|
||||
**Behavior**:
|
||||
- Retrieves all jobs from EXECUTION jobstore
|
||||
- Filters by graph_id and/or user_id if provided
|
||||
- Validates job kwargs as GraphExecutionJobArgs
|
||||
- Skips invalid jobs (ValidationError)
|
||||
- Only returns jobs with next_run_time set
|
||||
|
||||
### 4. System Task Endpoints
|
||||
|
||||
- `/execute_process_existing_batches` - Trigger batch processing
|
||||
- `/execute_process_weekly_summary` - Trigger weekly summary
|
||||
- `/execute_report_late_executions` - Trigger late execution report
|
||||
- `/execute_report_block_error_rates` - Trigger error rate report
|
||||
- `/execute_cleanup_expired_files` - Trigger file cleanup
|
||||
|
||||
### 5. Health Check
|
||||
|
||||
**Endpoints**: `/health_check`, `/health_check_async`
|
||||
**Methods**: POST, GET
|
||||
**Response**: "OK"
|
||||
|
||||
## Database Schema
|
||||
|
||||
### APScheduler Tables
|
||||
|
||||
The scheduler relies on APScheduler's SQLAlchemy jobstore schema:
|
||||
|
||||
1. **apscheduler_jobs**
|
||||
- id: VARCHAR (PRIMARY KEY)
|
||||
- next_run_time: FLOAT
|
||||
- job_state: BLOB/BYTEA (pickled job data)
|
||||
|
||||
2. **apscheduler_jobs_batched_notifications**
|
||||
- Same schema as above
|
||||
- Separate table for notification jobs
|
||||
|
||||
### Database Configuration
|
||||
|
||||
- URL extraction from `DIRECT_URL` environment variable
|
||||
- Schema extraction from URL query parameter
|
||||
- Connection pooling: `pool_size=db_pool_size()`, `max_overflow=0`
|
||||
- Metadata schema binding for multi-schema support
|
||||
|
||||
## External Dependencies
|
||||
|
||||
### Required Services
|
||||
|
||||
1. **PostgreSQL Database**
|
||||
- Connection via `DIRECT_URL` environment variable
|
||||
- Schema support via URL parameter
|
||||
- APScheduler job persistence
|
||||
|
||||
2. **ExecutionManager** (via execution_utils)
|
||||
- Function: `add_graph_execution`
|
||||
- Called by: `execute_graph` job function
|
||||
- Purpose: Create graph execution entries
|
||||
|
||||
3. **NotificationManager** (via monitoring module)
|
||||
- Functions: `process_existing_batches`, `queue_weekly_summary`
|
||||
- Purpose: Notification processing
|
||||
|
||||
4. **Cloud Storage** (via util.cloud_storage)
|
||||
- Function: `cleanup_expired_files_async`
|
||||
- Purpose: File expiration management
|
||||
|
||||
### Python Dependencies
|
||||
|
||||
```
|
||||
apscheduler>=3.10.0
|
||||
sqlalchemy
|
||||
pydantic>=2.0
|
||||
httpx
|
||||
uvicorn
|
||||
fastapi
|
||||
python-dotenv
|
||||
tenacity
|
||||
```
|
||||
|
||||
## Authentication & Authorization
|
||||
|
||||
### Service-Level Authentication
|
||||
|
||||
- No authentication required between internal services
|
||||
- Services communicate via trusted internal network
|
||||
- Host/port configuration via environment variables
|
||||
|
||||
### User-Level Authorization
|
||||
|
||||
- Authorization check in `delete_graph_execution_schedule`:
|
||||
- Validates `user_id` matches job's `user_id`
|
||||
- Raises `NotAuthorizedError` on mismatch
|
||||
- No authorization for read operations (security consideration)
|
||||
|
||||
## Process Management
|
||||
|
||||
### Startup Sequence
|
||||
|
||||
1. Load environment variables via `dotenv.load_dotenv()`
|
||||
2. Extract database URL and schema
|
||||
3. Initialize BlockingScheduler with configured jobstores
|
||||
4. Register system tasks (if enabled)
|
||||
5. Add job execution listener
|
||||
6. Start scheduler (blocking)
|
||||
|
||||
### Shutdown Sequence
|
||||
|
||||
1. Receive SIGTERM/SIGINT signal
|
||||
2. Call `cleanup()` method
|
||||
3. Shutdown scheduler with `wait=False`
|
||||
4. Terminate process
|
||||
|
||||
### Multi-Process Architecture
|
||||
|
||||
- Runs as independent process via `AppProcess`
|
||||
- Started by `run_processes()` in app.py
|
||||
- Can run in foreground or background mode
|
||||
- Automatic signal handling for graceful shutdown
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Job Execution Errors
|
||||
|
||||
- Listener on `EVENT_JOB_ERROR` logs failures
|
||||
- Errors in job functions are caught and logged
|
||||
- Jobs continue to run on schedule despite failures
|
||||
|
||||
### RPC Communication Errors
|
||||
|
||||
- Automatic retry via `@conn_retry` decorator
|
||||
- Configurable retry count and timeout
|
||||
- Connection pooling with self-healing
|
||||
|
||||
### Database Connection Errors
|
||||
|
||||
- APScheduler handles reconnection automatically
|
||||
- Pool exhaustion prevented by `max_overflow=0`
|
||||
- Connection errors logged but don't crash service
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
- `DIRECT_URL`: PostgreSQL connection string (required)
|
||||
- `{SERVICE_NAME}_HOST`: Override service host
|
||||
- Standard logging configuration
|
||||
|
||||
### Config Settings (via Config class)
|
||||
|
||||
```python
|
||||
execution_scheduler_port: int = 8003
|
||||
scheduler_db_pool_size: int = 3
|
||||
execution_late_notification_threshold_secs: int
|
||||
block_error_rate_check_interval_secs: int
|
||||
cloud_storage_cleanup_interval_hours: int
|
||||
pyro_host: str = "localhost"
|
||||
pyro_client_comm_timeout: float = 15
|
||||
pyro_client_comm_retry: int = 3
|
||||
rpc_client_call_timeout: int = 300
|
||||
```
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
### Unit Tests
|
||||
|
||||
1. Mock APScheduler for job management tests
|
||||
2. Mock database connections
|
||||
3. Test each RPC endpoint independently
|
||||
4. Verify job serialization/deserialization
|
||||
|
||||
### Integration Tests
|
||||
|
||||
1. Test with real PostgreSQL instance
|
||||
2. Verify job persistence across restarts
|
||||
3. Test concurrent job execution
|
||||
4. Validate cron expression parsing
|
||||
|
||||
### Critical Test Cases
|
||||
|
||||
1. **Job Persistence**: Jobs survive scheduler restart
|
||||
2. **User Isolation**: Users can only delete their own jobs
|
||||
3. **Concurrent Access**: Multiple clients can add/remove jobs
|
||||
4. **Error Recovery**: Service recovers from database outages
|
||||
5. **Resource Cleanup**: No memory/connection leaks
|
||||
|
||||
## Implementation Notes
|
||||
|
||||
### Key Design Decisions
|
||||
|
||||
1. **BlockingScheduler vs AsyncIOScheduler**: Uses BlockingScheduler for simplicity and compatibility with multiprocessing architecture
|
||||
|
||||
2. **Job Storage**: All job arguments stored in kwargs, not in job name/id
|
||||
|
||||
3. **Separate Jobstores**: Isolation between execution and notification jobs
|
||||
|
||||
4. **No Authentication**: Relies on network isolation for security
|
||||
|
||||
### Migration Considerations
|
||||
|
||||
1. APScheduler job format must be preserved exactly
|
||||
2. Database schema cannot change without migration
|
||||
3. RPC protocol must maintain compatibility
|
||||
4. Environment variables must match existing deployment
|
||||
|
||||
### Performance Considerations
|
||||
|
||||
1. Database pool size limited to prevent exhaustion
|
||||
2. No job result storage (fire-and-forget pattern)
|
||||
3. Minimal logging in hot paths
|
||||
4. Connection reuse via pooling
|
||||
|
||||
## Appendix: Critical Implementation Details
|
||||
|
||||
### Event Loop Management
|
||||
|
||||
```python
|
||||
@thread_cached
|
||||
def get_event_loop():
|
||||
return asyncio.new_event_loop()
|
||||
|
||||
def execute_graph(**kwargs):
|
||||
get_event_loop().run_until_complete(_execute_graph(**kwargs))
|
||||
```
|
||||
|
||||
### Job Function Execution Context
|
||||
|
||||
- Jobs run in scheduler's process space
|
||||
- Each job gets fresh event loop
|
||||
- No shared state between job executions
|
||||
- Exceptions logged but don't affect scheduler
|
||||
|
||||
### Cron Expression Format
|
||||
|
||||
- Uses standard crontab format via `CronTrigger.from_crontab()`
|
||||
- Supports: minute hour day month day_of_week
|
||||
- Special strings: @yearly, @monthly, @weekly, @daily, @hourly
|
||||
|
||||
This specification provides all necessary details to reimplement the scheduler service while maintaining 100% compatibility with the existing system. Any deviation from these specifications may result in system incompatibility.
|
||||
85
autogpt_platform/autogpt-rs/websocket/.github/workflows/ci.yml
vendored
Normal file
85
autogpt_platform/autogpt-rs/websocket/.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,85 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
RUSTFLAGS: "-D warnings"
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Test
|
||||
runs-on: ubuntu-latest
|
||||
services:
|
||||
redis:
|
||||
image: redis:7
|
||||
options: >-
|
||||
--health-cmd "redis-cli ping"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
ports:
|
||||
- 6379:6379
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- name: Run tests
|
||||
run: cargo test
|
||||
env:
|
||||
REDIS_URL: redis://localhost:6379
|
||||
|
||||
clippy:
|
||||
name: Clippy
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
components: clippy
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- name: Run clippy
|
||||
run: |
|
||||
cargo clippy -- \
|
||||
-D warnings \
|
||||
-D clippy::unwrap_used \
|
||||
-D clippy::panic \
|
||||
-D clippy::unimplemented \
|
||||
-D clippy::todo
|
||||
|
||||
fmt:
|
||||
name: Format
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
components: rustfmt
|
||||
- name: Check formatting
|
||||
run: cargo fmt -- --check
|
||||
|
||||
bench:
|
||||
name: Benchmarks
|
||||
runs-on: ubuntu-latest
|
||||
services:
|
||||
redis:
|
||||
image: redis:7
|
||||
options: >-
|
||||
--health-cmd "redis-cli ping"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
ports:
|
||||
- 6379:6379
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- name: Build benchmarks
|
||||
run: cargo bench --no-run
|
||||
env:
|
||||
REDIS_URL: redis://localhost:6379
|
||||
3382
autogpt_platform/autogpt-rs/websocket/Cargo.lock
generated
Normal file
3382
autogpt_platform/autogpt-rs/websocket/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
60
autogpt_platform/autogpt-rs/websocket/Cargo.toml
Normal file
60
autogpt_platform/autogpt-rs/websocket/Cargo.toml
Normal file
@@ -0,0 +1,60 @@
|
||||
[package]
|
||||
name = "websocket"
|
||||
authors = ["AutoGPT Team"]
|
||||
description = "WebSocket server for AutoGPT Platform"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
name = "websocket"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "websocket"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
axum = { version = "0.7.5", features = ["ws"] }
|
||||
jsonwebtoken = "9.3.0"
|
||||
redis = { version = "0.25.4", features = ["aio", "tokio-comp"] }
|
||||
serde = { version = "1.0.204", features = ["derive"] }
|
||||
serde_json = "1.0.120"
|
||||
tokio = { version = "1.38.1", features = ["rt-multi-thread", "macros", "net", "sync", "time", "io-util"] }
|
||||
tower-http = { version = "0.5.2", features = ["cors"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
futures = "0.3"
|
||||
dotenvy = "0.15"
|
||||
clap = { version = "4.5.4", features = ["derive"] }
|
||||
toml = "0.8"
|
||||
|
||||
[dev-dependencies]
|
||||
# Load testing and profiling
|
||||
tokio-console = "0.1"
|
||||
criterion = { version = "0.5", features = ["async_tokio"] }
|
||||
pprof = { version = "0.13", features = ["flamegraph", "criterion"] }
|
||||
# Dependencies for benchmarks
|
||||
tokio-tungstenite = "0.24"
|
||||
futures-util = "0.3"
|
||||
chrono = "0.4"
|
||||
|
||||
[[bench]]
|
||||
name = "websocket_bench"
|
||||
harness = false
|
||||
|
||||
[[example]]
|
||||
name = "ws_client_example"
|
||||
required-features = []
|
||||
|
||||
[profile.release]
|
||||
opt-level = 3 # Maximum optimization
|
||||
lto = true # Enable link-time optimization
|
||||
codegen-units = 1 # Reduce parallel code generation units to increase optimization
|
||||
panic = "abort" # Remove panic unwinding to reduce binary size
|
||||
strip = true # Strip symbols from binary
|
||||
|
||||
[profile.bench]
|
||||
opt-level = 3 # Maximum optimization
|
||||
lto = true # Enable link-time optimization
|
||||
codegen-units = 1 # Reduce parallel code generation units to increase optimization
|
||||
debug = true # Keep debug symbols for profiling
|
||||
412
autogpt_platform/autogpt-rs/websocket/README.md
Normal file
412
autogpt_platform/autogpt-rs/websocket/README.md
Normal file
@@ -0,0 +1,412 @@
|
||||
# WebSocket API Technical Specification
|
||||
|
||||
## Overview
|
||||
|
||||
This document provides a complete technical specification for the AutoGPT Platform WebSocket API (`ws_api.py`). The WebSocket API provides real-time updates for graph and node execution events, enabling clients to monitor workflow execution progress.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Core Components
|
||||
|
||||
1. **WebSocket Server** (`ws_api.py`)
|
||||
- FastAPI application with WebSocket endpoint
|
||||
- Handles client connections and message routing
|
||||
- Authenticates clients via JWT tokens
|
||||
- Manages subscriptions to execution events
|
||||
|
||||
2. **Connection Manager** (`conn_manager.py`)
|
||||
- Maintains active WebSocket connections
|
||||
- Manages channel subscriptions
|
||||
- Routes execution events to subscribed clients
|
||||
- Handles connection lifecycle
|
||||
|
||||
3. **Event Broadcasting System**
|
||||
- Redis Pub/Sub based event bus
|
||||
- Asynchronous event broadcaster
|
||||
- Execution event propagation from backend services
|
||||
|
||||
## API Endpoint
|
||||
|
||||
### WebSocket Endpoint
|
||||
- **URL**: `/ws`
|
||||
- **Protocol**: WebSocket (ws:// or wss://)
|
||||
- **Query Parameters**:
|
||||
- `token` (required when auth enabled): JWT authentication token
|
||||
|
||||
## Authentication
|
||||
|
||||
### JWT Token Authentication
|
||||
- **When Required**: When `settings.config.enable_auth` is `True`
|
||||
- **Token Location**: Query parameter `?token=<JWT_TOKEN>`
|
||||
- **Token Validation**:
|
||||
```python
|
||||
payload = parse_jwt_token(token)
|
||||
user_id = payload.get("sub")
|
||||
```
|
||||
- **JWT Requirements**:
|
||||
- Algorithm: Configured via `settings.JWT_ALGORITHM`
|
||||
- Secret Key: Configured via `settings.JWT_SECRET_KEY`
|
||||
- Audience: Must be "authenticated"
|
||||
- Claims: Must contain `sub` (user ID)
|
||||
|
||||
### Authentication Failures
|
||||
- **4001**: Missing authentication token
|
||||
- **4002**: Invalid token (missing user ID)
|
||||
- **4003**: Invalid token (parsing error or expired)
|
||||
|
||||
### No-Auth Mode
|
||||
- When `settings.config.enable_auth` is `False`
|
||||
- Uses `DEFAULT_USER_ID` from `backend.data.user`
|
||||
|
||||
## Message Protocol
|
||||
|
||||
### Message Format
|
||||
All messages use JSON format with the following structure:
|
||||
|
||||
```typescript
|
||||
interface WSMessage {
|
||||
method: WSMethod;
|
||||
data?: Record<string, any> | any[] | string;
|
||||
success?: boolean;
|
||||
channel?: string;
|
||||
error?: string;
|
||||
}
|
||||
```
|
||||
|
||||
### Message Methods (WSMethod enum)
|
||||
|
||||
1. **Client-to-Server Methods**:
|
||||
- `SUBSCRIBE_GRAPH_EXEC`: Subscribe to specific graph execution
|
||||
- `SUBSCRIBE_GRAPH_EXECS`: Subscribe to all executions of a graph
|
||||
- `UNSUBSCRIBE`: Unsubscribe from a channel
|
||||
- `HEARTBEAT`: Keep-alive ping
|
||||
|
||||
2. **Server-to-Client Methods**:
|
||||
- `GRAPH_EXECUTION_EVENT`: Graph execution status update
|
||||
- `NODE_EXECUTION_EVENT`: Node execution status update
|
||||
- `ERROR`: Error message
|
||||
- `HEARTBEAT`: Keep-alive pong
|
||||
|
||||
## Subscription Models
|
||||
|
||||
### Subscribe to Specific Graph Execution
|
||||
```typescript
|
||||
interface WSSubscribeGraphExecutionRequest {
|
||||
graph_exec_id: string;
|
||||
}
|
||||
```
|
||||
**Channel Key Format**: `{user_id}|graph_exec#{graph_exec_id}`
|
||||
|
||||
### Subscribe to All Graph Executions
|
||||
```typescript
|
||||
interface WSSubscribeGraphExecutionsRequest {
|
||||
graph_id: string;
|
||||
}
|
||||
```
|
||||
**Channel Key Format**: `{user_id}|graph#{graph_id}|executions`
|
||||
|
||||
## Event Models
|
||||
|
||||
### Graph Execution Event
|
||||
```typescript
|
||||
interface GraphExecutionEvent {
|
||||
event_type: "graph_execution_update";
|
||||
id: string; // graph_exec_id
|
||||
user_id: string;
|
||||
graph_id: string;
|
||||
graph_version: number;
|
||||
preset_id?: string;
|
||||
status: ExecutionStatus;
|
||||
started_at: string; // ISO datetime
|
||||
ended_at: string; // ISO datetime
|
||||
inputs: Record<string, any>;
|
||||
outputs: Record<string, any>;
|
||||
stats?: {
|
||||
cost: number; // cents
|
||||
duration: number; // seconds
|
||||
duration_cpu_only: number;
|
||||
node_exec_time: number;
|
||||
node_exec_time_cpu_only: number;
|
||||
node_exec_count: number;
|
||||
node_error_count: number;
|
||||
error?: string;
|
||||
};
|
||||
}
|
||||
```
|
||||
|
||||
### Node Execution Event
|
||||
```typescript
|
||||
interface NodeExecutionEvent {
|
||||
event_type: "node_execution_update";
|
||||
user_id: string;
|
||||
graph_id: string;
|
||||
graph_version: number;
|
||||
graph_exec_id: string;
|
||||
node_exec_id: string;
|
||||
node_id: string;
|
||||
block_id: string;
|
||||
status: ExecutionStatus;
|
||||
input_data: Record<string, any>;
|
||||
output_data: Record<string, any>;
|
||||
add_time: string; // ISO datetime
|
||||
queue_time?: string; // ISO datetime
|
||||
start_time?: string; // ISO datetime
|
||||
end_time?: string; // ISO datetime
|
||||
}
|
||||
```
|
||||
|
||||
### Execution Status Enum
|
||||
```typescript
|
||||
enum ExecutionStatus {
|
||||
INCOMPLETE = "INCOMPLETE",
|
||||
QUEUED = "QUEUED",
|
||||
RUNNING = "RUNNING",
|
||||
COMPLETED = "COMPLETED",
|
||||
FAILED = "FAILED"
|
||||
}
|
||||
```
|
||||
|
||||
## Message Flow Examples
|
||||
|
||||
### 1. Subscribe to Graph Execution
|
||||
```json
|
||||
// Client → Server
|
||||
{
|
||||
"method": "subscribe_graph_execution",
|
||||
"data": {
|
||||
"graph_exec_id": "exec-123"
|
||||
}
|
||||
}
|
||||
|
||||
// Server → Client (Success)
|
||||
{
|
||||
"method": "subscribe_graph_execution",
|
||||
"success": true,
|
||||
"channel": "user-456|graph_exec#exec-123"
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Receive Execution Updates
|
||||
```json
|
||||
// Server → Client (Graph Update)
|
||||
{
|
||||
"method": "graph_execution_event",
|
||||
"channel": "user-456|graph_exec#exec-123",
|
||||
"data": {
|
||||
"event_type": "graph_execution_update",
|
||||
"id": "exec-123",
|
||||
"user_id": "user-456",
|
||||
"graph_id": "graph-789",
|
||||
"status": "RUNNING",
|
||||
// ... other fields
|
||||
}
|
||||
}
|
||||
|
||||
// Server → Client (Node Update)
|
||||
{
|
||||
"method": "node_execution_event",
|
||||
"channel": "user-456|graph_exec#exec-123",
|
||||
"data": {
|
||||
"event_type": "node_execution_update",
|
||||
"node_exec_id": "node-exec-111",
|
||||
"status": "COMPLETED",
|
||||
// ... other fields
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Heartbeat
|
||||
```json
|
||||
// Client → Server
|
||||
{
|
||||
"method": "heartbeat",
|
||||
"data": "ping"
|
||||
}
|
||||
|
||||
// Server → Client
|
||||
{
|
||||
"method": "heartbeat",
|
||||
"data": "pong",
|
||||
"success": true
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Error Handling
|
||||
```json
|
||||
// Server → Client (Invalid Message)
|
||||
{
|
||||
"method": "error",
|
||||
"success": false,
|
||||
"error": "Invalid message format. Review the schema and retry"
|
||||
}
|
||||
```
|
||||
|
||||
## Event Broadcasting Architecture
|
||||
|
||||
### Redis Pub/Sub Integration
|
||||
1. **Event Bus Name**: Configured via `config.execution_event_bus_name`
|
||||
2. **Channel Pattern**: `{event_bus_name}/{channel_key}`
|
||||
3. **Event Flow**:
|
||||
- Execution services publish events to Redis
|
||||
- Event broadcaster listens to Redis pattern `*`
|
||||
- Events are routed to WebSocket connections based on subscriptions
|
||||
|
||||
### Event Broadcaster
|
||||
- Runs as continuous async task using `@continuous_retry()` decorator
|
||||
- Listens to all execution events via `AsyncRedisExecutionEventBus`
|
||||
- Calls `ConnectionManager.send_execution_update()` for each event
|
||||
|
||||
## Connection Lifecycle
|
||||
|
||||
### Connection Establishment
|
||||
1. Client connects to `/ws` endpoint
|
||||
2. Authentication performed (JWT validation)
|
||||
3. WebSocket accepted via `manager.connect_socket()`
|
||||
4. Connection added to active connections set
|
||||
|
||||
### Message Processing Loop
|
||||
1. Receive text message from client
|
||||
2. Parse and validate as `WSMessage`
|
||||
3. Route to appropriate handler based on `method`
|
||||
4. Send response or error back to client
|
||||
|
||||
### Connection Termination
|
||||
1. `WebSocketDisconnect` exception caught
|
||||
2. `manager.disconnect_socket()` called
|
||||
3. Connection removed from active connections
|
||||
4. All subscriptions for that connection removed
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Validation Errors
|
||||
- **Invalid Message Format**: Returns error with method "error"
|
||||
- **Invalid Message Data**: Returns error with specific validation message
|
||||
- **Unknown Message Type**: Returns error indicating unsupported method
|
||||
|
||||
### Connection Errors
|
||||
- WebSocket disconnections handled gracefully
|
||||
- Failed event parsing logged but doesn't crash connection
|
||||
- Handler exceptions logged and connection continues
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
```python
|
||||
# WebSocket Server Configuration
|
||||
websocket_server_host: str = "0.0.0.0"
|
||||
websocket_server_port: int = 8001
|
||||
|
||||
# Authentication
|
||||
enable_auth: bool = True
|
||||
|
||||
# CORS
|
||||
backend_cors_allow_origins: List[str] = []
|
||||
|
||||
# Redis Event Bus
|
||||
execution_event_bus_name: str = "autogpt:execution_event_bus"
|
||||
|
||||
# Message Size Limits
|
||||
max_message_size_limit: int = 512000 # 512KB
|
||||
```
|
||||
|
||||
### Security Headers
|
||||
- CORS middleware applied with configured origins
|
||||
- Credentials allowed for authenticated requests
|
||||
- All methods and headers allowed (configurable)
|
||||
|
||||
## Deployment Requirements
|
||||
|
||||
### Dependencies
|
||||
1. **FastAPI**: Web framework with WebSocket support
|
||||
2. **Redis**: For pub/sub event broadcasting
|
||||
3. **JWT Libraries**: For token validation
|
||||
4. **Prisma**: Database ORM (for future graph access validation)
|
||||
|
||||
### Process Management
|
||||
- Implements `AppProcess` interface for service lifecycle
|
||||
- Runs via `uvicorn` ASGI server
|
||||
- Graceful shutdown handling in `cleanup()` method
|
||||
|
||||
### Concurrent Connections
|
||||
- No hard limit on WebSocket connections
|
||||
- Memory usage scales with active connections
|
||||
- Each connection maintains subscription set
|
||||
|
||||
## Implementation Checklist
|
||||
|
||||
To implement a compatible WebSocket API:
|
||||
|
||||
1. **Authentication**
|
||||
- [ ] JWT token validation from query parameters
|
||||
- [ ] Support for no-auth mode with default user ID
|
||||
- [ ] Proper error codes for auth failures
|
||||
|
||||
2. **Message Handling**
|
||||
- [ ] Parse and validate WSMessage format
|
||||
- [ ] Implement all client-to-server methods
|
||||
- [ ] Support all server-to-client event types
|
||||
- [ ] Proper error responses for invalid messages
|
||||
|
||||
3. **Subscription Management**
|
||||
- [ ] Channel key generation matching exact format
|
||||
- [ ] Support for both execution and graph-level subscriptions
|
||||
- [ ] Unsubscribe functionality
|
||||
- [ ] Clean up subscriptions on disconnect
|
||||
|
||||
4. **Event Broadcasting**
|
||||
- [ ] Listen to Redis pub/sub for execution events
|
||||
- [ ] Route events to correct subscribed connections
|
||||
- [ ] Handle both graph and node execution events
|
||||
- [ ] Maintain event order and completeness
|
||||
|
||||
5. **Connection Management**
|
||||
- [ ] Track active WebSocket connections
|
||||
- [ ] Handle graceful disconnections
|
||||
- [ ] Implement heartbeat/keepalive
|
||||
- [ ] Memory-efficient subscription storage
|
||||
|
||||
6. **Configuration**
|
||||
- [ ] Support all environment variables
|
||||
- [ ] CORS configuration for allowed origins
|
||||
- [ ] Configurable host/port binding
|
||||
- [ ] Redis connection configuration
|
||||
|
||||
7. **Error Handling**
|
||||
- [ ] Graceful handling of malformed messages
|
||||
- [ ] Logging of errors without dropping connections
|
||||
- [ ] Specific error messages for debugging
|
||||
- [ ] Recovery from Redis connection issues
|
||||
|
||||
## Testing Considerations
|
||||
|
||||
1. **Unit Tests**
|
||||
- Message parsing and validation
|
||||
- Channel key generation
|
||||
- Subscription management logic
|
||||
|
||||
2. **Integration Tests**
|
||||
- Full WebSocket connection flow
|
||||
- Event broadcasting from Redis
|
||||
- Multi-client subscription scenarios
|
||||
- Authentication success/failure cases
|
||||
|
||||
3. **Load Tests**
|
||||
- Many concurrent connections
|
||||
- High-frequency event broadcasting
|
||||
- Memory usage under load
|
||||
- Connection/disconnection cycles
|
||||
|
||||
## Security Considerations
|
||||
|
||||
1. **Authentication**: JWT tokens transmitted via query parameters (consider upgrading to headers)
|
||||
2. **Authorization**: Currently no graph-level access validation (commented out in code)
|
||||
3. **Rate Limiting**: No rate limiting implemented
|
||||
4. **Message Size**: Limited by `max_message_size_limit` configuration
|
||||
5. **Input Validation**: All inputs validated via Pydantic models
|
||||
|
||||
## Future Enhancements (Currently Commented Out)
|
||||
|
||||
1. **Graph Access Validation**: Verify user has read access to subscribed graphs
|
||||
2. **Message Compression**: For large execution payloads
|
||||
3. **Batch Updates**: Aggregate multiple events in single message
|
||||
4. **Selective Field Subscription**: Subscribe to specific fields only
|
||||
93
autogpt_platform/autogpt-rs/websocket/benches/README.md
Normal file
93
autogpt_platform/autogpt-rs/websocket/benches/README.md
Normal file
@@ -0,0 +1,93 @@
|
||||
# WebSocket Server Benchmarks
|
||||
|
||||
This directory contains performance benchmarks for the AutoGPT WebSocket server.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
1. Redis must be running locally or set `REDIS_URL` environment variable:
|
||||
```bash
|
||||
docker run -d -p 6379:6379 redis:latest
|
||||
```
|
||||
|
||||
2. Build the project in release mode:
|
||||
```bash
|
||||
cargo build --release
|
||||
```
|
||||
|
||||
## Running Benchmarks
|
||||
|
||||
Run all benchmarks:
|
||||
```bash
|
||||
cargo bench
|
||||
```
|
||||
|
||||
Run specific benchmark group:
|
||||
```bash
|
||||
cargo bench connection_establishment
|
||||
cargo bench subscriptions
|
||||
cargo bench message_throughput
|
||||
cargo bench concurrent_connections
|
||||
cargo bench message_parsing
|
||||
cargo bench redis_event_processing
|
||||
```
|
||||
|
||||
## Benchmark Categories
|
||||
|
||||
### Connection Establishment
|
||||
Tests the performance of establishing WebSocket connections with different authentication scenarios:
|
||||
- No authentication
|
||||
- Valid JWT authentication
|
||||
- Invalid JWT authentication (connection rejection)
|
||||
|
||||
### Subscriptions
|
||||
Measures the performance of subscription operations:
|
||||
- Subscribing to graph execution events
|
||||
- Unsubscribing from channels
|
||||
|
||||
### Message Throughput
|
||||
Tests how many messages the server can process per second with varying message counts (10, 100, 1000).
|
||||
|
||||
### Concurrent Connections
|
||||
Benchmarks the server's ability to handle multiple simultaneous connections (10, 50, 100, 500 clients).
|
||||
|
||||
### Message Parsing
|
||||
Tests JSON parsing performance with different message sizes (100B to 100KB).
|
||||
|
||||
### Redis Event Processing
|
||||
Benchmarks the parsing of execution events received from Redis.
|
||||
|
||||
## Profiling
|
||||
|
||||
To generate flamegraphs for CPU profiling:
|
||||
|
||||
1. Install flamegraph tools:
|
||||
```bash
|
||||
cargo install flamegraph
|
||||
```
|
||||
|
||||
2. Run benchmarks with profiling:
|
||||
```bash
|
||||
cargo bench --bench websocket_bench -- --profile-time=10
|
||||
```
|
||||
|
||||
## Interpreting Results
|
||||
|
||||
- **Throughput**: Higher is better (operations/second or elements/second)
|
||||
- **Time**: Lower is better (nanoseconds per operation)
|
||||
- **Error margins**: Look for stable results with low standard deviation
|
||||
|
||||
## Optimizing Performance
|
||||
|
||||
Based on benchmark results, consider:
|
||||
|
||||
1. **Connection pooling** for Redis connections
|
||||
2. **Message batching** for high-throughput scenarios
|
||||
3. **Async task tuning** for concurrent connection handling
|
||||
4. **JSON parsing optimization** using simd-json or other fast parsers
|
||||
5. **Memory allocation** optimization using arena allocators
|
||||
|
||||
## Notes
|
||||
|
||||
- Benchmarks create actual WebSocket servers on random ports
|
||||
- Each benchmark iteration properly cleans up resources
|
||||
- Results may vary based on system resources and Redis performance
|
||||
406
autogpt_platform/autogpt-rs/websocket/benches/websocket_bench.rs
Normal file
406
autogpt_platform/autogpt-rs/websocket/benches/websocket_bench.rs
Normal file
@@ -0,0 +1,406 @@
|
||||
#![allow(clippy::unwrap_used)] // Benchmarks can panic on setup errors
|
||||
|
||||
use axum::{routing::get, Router};
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::runtime::Runtime;
|
||||
use tokio_tungstenite::{connect_async, tungstenite::Message};
|
||||
|
||||
// Import the actual websocket server components
|
||||
use websocket::{models, ws_handler, AppState, Config, ConnectionManager, Stats};
|
||||
|
||||
// Helper to create a test server
|
||||
async fn create_test_server(enable_auth: bool) -> (String, tokio::task::JoinHandle<()>) {
|
||||
// Set environment variables for test config
|
||||
std::env::set_var("WEBSOCKET_SERVER_HOST", "127.0.0.1");
|
||||
std::env::set_var("WEBSOCKET_SERVER_PORT", "0");
|
||||
std::env::set_var("ENABLE_AUTH", enable_auth.to_string());
|
||||
std::env::set_var("SUPABASE_JWT_SECRET", "test_secret");
|
||||
std::env::set_var("DEFAULT_USER_ID", "test_user");
|
||||
if std::env::var("REDIS_URL").is_err() {
|
||||
std::env::set_var("REDIS_URL", "redis://localhost:6379");
|
||||
}
|
||||
|
||||
let mut config = Config::load(None);
|
||||
config.port = 0; // Force OS to assign port
|
||||
|
||||
let redis_client =
|
||||
redis::Client::open(config.redis_url.clone()).expect("Failed to connect to Redis");
|
||||
let stats = Arc::new(Stats::default());
|
||||
let mgr = Arc::new(ConnectionManager::new(
|
||||
redis_client,
|
||||
config.execution_event_bus_name.clone(),
|
||||
stats.clone(),
|
||||
));
|
||||
|
||||
// Start broadcaster
|
||||
let mgr_clone = mgr.clone();
|
||||
tokio::spawn(async move {
|
||||
mgr_clone.run_broadcaster().await;
|
||||
});
|
||||
|
||||
let state = AppState {
|
||||
mgr,
|
||||
config: Arc::new(config),
|
||||
stats,
|
||||
};
|
||||
|
||||
let app = Router::new()
|
||||
.route("/ws", get(ws_handler))
|
||||
.layer(axum::Extension(state));
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let server_url = format!("ws://{addr}");
|
||||
|
||||
let server_handle = tokio::spawn(async move {
|
||||
axum::serve(listener, app.into_make_service())
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
// Give server time to start
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
|
||||
(server_url, server_handle)
|
||||
}
|
||||
|
||||
// Helper to create a valid JWT token
|
||||
fn create_jwt_token(user_id: &str) -> String {
|
||||
use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct Claims {
|
||||
sub: String,
|
||||
aud: Vec<String>,
|
||||
exp: usize,
|
||||
}
|
||||
|
||||
let claims = Claims {
|
||||
sub: user_id.to_string(),
|
||||
aud: vec!["authenticated".to_string()],
|
||||
exp: (chrono::Utc::now() + chrono::Duration::hours(1)).timestamp() as usize,
|
||||
};
|
||||
|
||||
encode(
|
||||
&Header::new(Algorithm::HS256),
|
||||
&claims,
|
||||
&EncodingKey::from_secret(b"test_secret"),
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
// Benchmark connection establishment
|
||||
fn benchmark_connection_establishment(c: &mut Criterion) {
|
||||
let rt = Runtime::new().unwrap();
|
||||
|
||||
let mut group = c.benchmark_group("connection_establishment");
|
||||
group.measurement_time(Duration::from_secs(30));
|
||||
|
||||
// Test without auth
|
||||
group.bench_function("no_auth", |b| {
|
||||
b.to_async(&rt).iter_with_large_drop(|| async {
|
||||
let (server_url, server_handle) = create_test_server(false).await;
|
||||
let url = format!("{server_url}/ws");
|
||||
let (ws_stream, _) = connect_async(&url).await.unwrap();
|
||||
drop(ws_stream);
|
||||
server_handle.abort();
|
||||
});
|
||||
});
|
||||
|
||||
// Test with valid auth
|
||||
group.bench_function("valid_auth", |b| {
|
||||
b.to_async(&rt).iter_with_large_drop(|| async {
|
||||
let (server_url, server_handle) = create_test_server(true).await;
|
||||
let token = create_jwt_token("test_user");
|
||||
let url = format!("{server_url}/ws?token={token}");
|
||||
let (ws_stream, _) = connect_async(&url).await.unwrap();
|
||||
drop(ws_stream);
|
||||
server_handle.abort();
|
||||
});
|
||||
});
|
||||
|
||||
// Test with invalid auth
|
||||
group.bench_function("invalid_auth", |b| {
|
||||
b.to_async(&rt).iter_with_large_drop(|| async {
|
||||
let (server_url, server_handle) = create_test_server(true).await;
|
||||
let url = format!("{server_url}/ws?token=invalid");
|
||||
let result = connect_async(&url).await;
|
||||
assert!(
|
||||
result.is_err() || {
|
||||
if let Ok((mut ws_stream, _)) = result {
|
||||
// Should receive close frame
|
||||
matches!(ws_stream.next().await, Some(Ok(Message::Close(_))))
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
);
|
||||
server_handle.abort();
|
||||
});
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// Benchmark subscription operations
|
||||
fn benchmark_subscriptions(c: &mut Criterion) {
|
||||
let rt = Runtime::new().unwrap();
|
||||
|
||||
let mut group = c.benchmark_group("subscriptions");
|
||||
group.measurement_time(Duration::from_secs(20));
|
||||
|
||||
group.bench_function("subscribe_graph_execution", |b| {
|
||||
b.to_async(&rt).iter_with_large_drop(|| async {
|
||||
let (server_url, server_handle) = create_test_server(false).await;
|
||||
let url = format!("{server_url}/ws");
|
||||
let (mut ws_stream, _) = connect_async(&url).await.unwrap();
|
||||
let msg = json!({
|
||||
"method": "subscribe_graph_execution",
|
||||
"data": {
|
||||
"graph_exec_id": "test_exec_123"
|
||||
}
|
||||
});
|
||||
|
||||
ws_stream
|
||||
.send(Message::Text(msg.to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Wait for response
|
||||
if let Some(Ok(Message::Text(response))) = ws_stream.next().await {
|
||||
let resp: serde_json::Value = serde_json::from_str(&response).unwrap();
|
||||
assert_eq!(resp["success"], true);
|
||||
}
|
||||
|
||||
server_handle.abort();
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("unsubscribe", |b| {
|
||||
b.to_async(&rt).iter_with_large_drop(|| async {
|
||||
let (server_url, server_handle) = create_test_server(false).await;
|
||||
let url = format!("{server_url}/ws");
|
||||
let (mut ws_stream, _) = connect_async(&url).await.unwrap();
|
||||
|
||||
// First subscribe
|
||||
let msg = json!({
|
||||
"method": "subscribe_graph_execution",
|
||||
"data": {
|
||||
"graph_exec_id": "test_exec_123"
|
||||
}
|
||||
});
|
||||
ws_stream
|
||||
.send(Message::Text(msg.to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
ws_stream.next().await; // Consume response
|
||||
let msg = json!({
|
||||
"method": "unsubscribe",
|
||||
"data": {
|
||||
"channel": "test_user|graph_exec#test_exec_123"
|
||||
}
|
||||
});
|
||||
|
||||
ws_stream
|
||||
.send(Message::Text(msg.to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Wait for response
|
||||
if let Some(Ok(Message::Text(response))) = ws_stream.next().await {
|
||||
let resp: serde_json::Value = serde_json::from_str(&response).unwrap();
|
||||
assert_eq!(resp["success"], true);
|
||||
}
|
||||
|
||||
server_handle.abort();
|
||||
});
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// Benchmark message throughput
|
||||
fn benchmark_message_throughput(c: &mut Criterion) {
|
||||
let rt = Runtime::new().unwrap();
|
||||
|
||||
let mut group = c.benchmark_group("message_throughput");
|
||||
group.measurement_time(Duration::from_secs(30));
|
||||
|
||||
for msg_count in [10, 100, 1000].iter() {
|
||||
group.throughput(Throughput::Elements(*msg_count as u64));
|
||||
group.bench_with_input(
|
||||
BenchmarkId::from_parameter(msg_count),
|
||||
msg_count,
|
||||
|b, &msg_count| {
|
||||
b.to_async(&rt).iter_with_large_drop(|| async {
|
||||
let (server_url, server_handle) = create_test_server(false).await;
|
||||
let url = format!("{server_url}/ws");
|
||||
let (mut ws_stream, _) = connect_async(&url).await.unwrap();
|
||||
// Send multiple heartbeat messages
|
||||
for _ in 0..msg_count {
|
||||
let msg = json!({
|
||||
"method": "heartbeat",
|
||||
"data": "ping"
|
||||
});
|
||||
ws_stream
|
||||
.send(Message::Text(msg.to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Receive all responses
|
||||
for _ in 0..msg_count {
|
||||
ws_stream.next().await;
|
||||
}
|
||||
|
||||
server_handle.abort();
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// Benchmark concurrent connections
|
||||
fn benchmark_concurrent_connections(c: &mut Criterion) {
|
||||
let rt = Runtime::new().unwrap();
|
||||
|
||||
let mut group = c.benchmark_group("concurrent_connections");
|
||||
group.measurement_time(Duration::from_secs(60));
|
||||
group.sample_size(10);
|
||||
|
||||
for num_clients in [100, 500, 1000].iter() {
|
||||
group.throughput(Throughput::Elements(*num_clients as u64));
|
||||
group.bench_with_input(
|
||||
BenchmarkId::from_parameter(num_clients),
|
||||
num_clients,
|
||||
|b, &num_clients| {
|
||||
b.to_async(&rt).iter_with_large_drop(|| async {
|
||||
let (server_url, server_handle) = create_test_server(false).await;
|
||||
let url = format!("{server_url}/ws");
|
||||
|
||||
// Create multiple concurrent connections
|
||||
let mut handles = vec![];
|
||||
for i in 0..num_clients {
|
||||
let url = url.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
let (mut ws_stream, _) = connect_async(&url).await.unwrap();
|
||||
|
||||
// Subscribe to a unique channel
|
||||
let msg = json!({
|
||||
"method": "subscribe_graph_execution",
|
||||
"data": {
|
||||
"graph_exec_id": format!("exec_{}", i)
|
||||
}
|
||||
});
|
||||
ws_stream
|
||||
.send(Message::Text(msg.to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
ws_stream.next().await; // Wait for response
|
||||
|
||||
// Send a heartbeat
|
||||
let msg = json!({
|
||||
"method": "heartbeat",
|
||||
"data": "ping"
|
||||
});
|
||||
ws_stream
|
||||
.send(Message::Text(msg.to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
ws_stream.next().await; // Wait for response
|
||||
|
||||
ws_stream
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Wait for all connections to complete
|
||||
for handle in handles {
|
||||
let _ = handle.await;
|
||||
}
|
||||
|
||||
server_handle.abort();
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// Benchmark message parsing
|
||||
fn benchmark_message_parsing(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("message_parsing");
|
||||
|
||||
// Test different message sizes
|
||||
for msg_size in [100, 1000, 10000].iter() {
|
||||
group.throughput(Throughput::Bytes(*msg_size as u64));
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("parse_json", msg_size),
|
||||
msg_size,
|
||||
|b, &msg_size| {
|
||||
let data_str = "x".repeat(msg_size);
|
||||
let json_msg = json!({
|
||||
"method": "subscribe_graph_execution",
|
||||
"data": {
|
||||
"graph_exec_id": data_str
|
||||
}
|
||||
});
|
||||
let json_str = json_msg.to_string();
|
||||
|
||||
b.iter(|| {
|
||||
let _: models::WSMessage = serde_json::from_str(&json_str).unwrap();
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// Benchmark Redis event processing
|
||||
fn benchmark_redis_event_processing(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("redis_event_processing");
|
||||
|
||||
group.bench_function("parse_execution_event", |b| {
|
||||
let event = json!({
|
||||
"payload": {
|
||||
"event_type": "graph_execution_update",
|
||||
"id": "exec_123",
|
||||
"graph_id": "graph_456",
|
||||
"graph_version": 1,
|
||||
"user_id": "user_789",
|
||||
"status": "RUNNING",
|
||||
"started_at": "2024-01-01T00:00:00Z",
|
||||
"inputs": {"test": "data"},
|
||||
"outputs": {}
|
||||
}
|
||||
});
|
||||
let event_str = event.to_string();
|
||||
|
||||
b.iter(|| {
|
||||
let _: models::RedisEventWrapper = serde_json::from_str(&event_str).unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
benches,
|
||||
benchmark_connection_establishment,
|
||||
benchmark_subscriptions,
|
||||
benchmark_message_throughput,
|
||||
benchmark_concurrent_connections,
|
||||
benchmark_message_parsing,
|
||||
benchmark_redis_event_processing
|
||||
);
|
||||
criterion_main!(benches);
|
||||
10
autogpt_platform/autogpt-rs/websocket/clippy.toml
Normal file
10
autogpt_platform/autogpt-rs/websocket/clippy.toml
Normal file
@@ -0,0 +1,10 @@
|
||||
# Clippy configuration for robust error handling
|
||||
|
||||
# Set the maximum cognitive complexity allowed
|
||||
cognitive-complexity-threshold = 30
|
||||
|
||||
# Warn on TODO/FIXME comments
|
||||
allow-dbg-in-tests = false
|
||||
|
||||
# Enforce documentation
|
||||
missing-docs-in-crate-items = true
|
||||
23
autogpt_platform/autogpt-rs/websocket/config.toml
Normal file
23
autogpt_platform/autogpt-rs/websocket/config.toml
Normal file
@@ -0,0 +1,23 @@
|
||||
# WebSocket API Configuration
|
||||
|
||||
# Server settings
|
||||
host = "0.0.0.0"
|
||||
port = 8001
|
||||
|
||||
# Authentication
|
||||
enable_auth = true
|
||||
jwt_secret = "your-super-secret-jwt-token-with-at-least-32-characters-long"
|
||||
jwt_algorithm = "HS256"
|
||||
default_user_id = "default"
|
||||
|
||||
# Redis configuration
|
||||
redis_url = "redis://:password@localhost:6379/"
|
||||
|
||||
# Event bus
|
||||
execution_event_bus_name = "execution_event"
|
||||
|
||||
# Message size limit (in bytes)
|
||||
max_message_size_limit = 512000
|
||||
|
||||
# CORS allowed origins
|
||||
backend_cors_allow_origins = ["http://localhost:3000", "https://559f69c159ef.ngrok.app"]
|
||||
@@ -0,0 +1,75 @@
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use serde_json::json;
|
||||
use tokio_tungstenite::{connect_async, tungstenite::Message};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = "ws://localhost:8001/ws";
|
||||
|
||||
println!("Connecting to {url}");
|
||||
let (mut ws_stream, _) = connect_async(url).await?;
|
||||
println!("Connected!");
|
||||
|
||||
// Subscribe to a graph execution
|
||||
let subscribe_msg = json!({
|
||||
"method": "subscribe_graph_execution",
|
||||
"data": {
|
||||
"graph_exec_id": "test_exec_123"
|
||||
}
|
||||
});
|
||||
|
||||
println!("Sending subscription request...");
|
||||
ws_stream
|
||||
.send(Message::Text(subscribe_msg.to_string()))
|
||||
.await?;
|
||||
|
||||
// Wait for response
|
||||
if let Some(msg) = ws_stream.next().await {
|
||||
if let Message::Text(text) = msg? {
|
||||
println!("Received: {text}");
|
||||
}
|
||||
}
|
||||
|
||||
// Send heartbeat
|
||||
let heartbeat_msg = json!({
|
||||
"method": "heartbeat",
|
||||
"data": "ping"
|
||||
});
|
||||
|
||||
println!("Sending heartbeat...");
|
||||
ws_stream
|
||||
.send(Message::Text(heartbeat_msg.to_string()))
|
||||
.await?;
|
||||
|
||||
// Wait for pong
|
||||
if let Some(msg) = ws_stream.next().await {
|
||||
if let Message::Text(text) = msg? {
|
||||
println!("Received: {text}");
|
||||
}
|
||||
}
|
||||
|
||||
// Unsubscribe
|
||||
let unsubscribe_msg = json!({
|
||||
"method": "unsubscribe",
|
||||
"data": {
|
||||
"channel": "default|graph_exec#test_exec_123"
|
||||
}
|
||||
});
|
||||
|
||||
println!("Sending unsubscribe request...");
|
||||
ws_stream
|
||||
.send(Message::Text(unsubscribe_msg.to_string()))
|
||||
.await?;
|
||||
|
||||
// Wait for response
|
||||
if let Some(msg) = ws_stream.next().await {
|
||||
if let Message::Text(text) = msg? {
|
||||
println!("Received: {text}");
|
||||
}
|
||||
}
|
||||
|
||||
println!("Closing connection...");
|
||||
ws_stream.close(None).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
99
autogpt_platform/autogpt-rs/websocket/src/config.rs
Normal file
99
autogpt_platform/autogpt-rs/websocket/src/config.rs
Normal file
@@ -0,0 +1,99 @@
|
||||
use jsonwebtoken::Algorithm;
|
||||
use serde::Deserialize;
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::str::FromStr;
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct Config {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub enable_auth: bool,
|
||||
pub jwt_secret: String,
|
||||
pub jwt_algorithm: Algorithm,
|
||||
pub execution_event_bus_name: String,
|
||||
pub redis_url: String,
|
||||
pub default_user_id: String,
|
||||
pub max_message_size_limit: usize,
|
||||
pub backend_cors_allow_origins: Vec<String>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn load(config_path: Option<&Path>) -> Self {
|
||||
let path = config_path.unwrap_or(Path::new("config.toml"));
|
||||
let toml_result = fs::read_to_string(path)
|
||||
.ok()
|
||||
.and_then(|s| toml::from_str::<Config>(&s).ok());
|
||||
|
||||
let mut config = match toml_result {
|
||||
Some(config) => config,
|
||||
None => Config {
|
||||
host: env::var("WEBSOCKET_SERVER_HOST").unwrap_or_else(|_| "0.0.0.0".to_string()),
|
||||
port: env::var("WEBSOCKET_SERVER_PORT")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(8001),
|
||||
enable_auth: env::var("ENABLE_AUTH")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(true),
|
||||
jwt_secret: env::var("SUPABASE_JWT_SECRET")
|
||||
.unwrap_or_else(|_| "dummy_secret_for_no_auth".to_string()),
|
||||
jwt_algorithm: Algorithm::HS256,
|
||||
execution_event_bus_name: env::var("EXECUTION_EVENT_BUS_NAME")
|
||||
.unwrap_or_else(|_| "execution_event".to_string()),
|
||||
redis_url: env::var("REDIS_URL")
|
||||
.unwrap_or_else(|_| "redis://localhost/".to_string()),
|
||||
default_user_id: "default".to_string(),
|
||||
max_message_size_limit: env::var("MAX_MESSAGE_SIZE_LIMIT")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(512000),
|
||||
backend_cors_allow_origins: env::var("BACKEND_CORS_ALLOW_ORIGINS")
|
||||
.unwrap_or_default()
|
||||
.split(',')
|
||||
.map(|s| s.trim().to_string())
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect(),
|
||||
},
|
||||
};
|
||||
|
||||
if let Ok(v) = env::var("WEBSOCKET_SERVER_HOST") {
|
||||
config.host = v;
|
||||
}
|
||||
if let Ok(v) = env::var("WEBSOCKET_SERVER_PORT") {
|
||||
config.port = v.parse().unwrap_or(8001);
|
||||
}
|
||||
if let Ok(v) = env::var("ENABLE_AUTH") {
|
||||
config.enable_auth = v.parse().unwrap_or(true);
|
||||
}
|
||||
if let Ok(v) = env::var("SUPABASE_JWT_SECRET") {
|
||||
config.jwt_secret = v;
|
||||
}
|
||||
if let Ok(v) = env::var("JWT_ALGORITHM") {
|
||||
config.jwt_algorithm = Algorithm::from_str(&v).unwrap_or(Algorithm::HS256);
|
||||
}
|
||||
if let Ok(v) = env::var("EXECUTION_EVENT_BUS_NAME") {
|
||||
config.execution_event_bus_name = v;
|
||||
}
|
||||
if let Ok(v) = env::var("REDIS_URL") {
|
||||
config.redis_url = v;
|
||||
}
|
||||
if let Ok(v) = env::var("DEFAULT_USER_ID") {
|
||||
config.default_user_id = v;
|
||||
}
|
||||
if let Ok(v) = env::var("MAX_MESSAGE_SIZE_LIMIT") {
|
||||
config.max_message_size_limit = v.parse().unwrap_or(512000);
|
||||
}
|
||||
if let Ok(v) = env::var("BACKEND_CORS_ALLOW_ORIGINS") {
|
||||
config.backend_cors_allow_origins = v
|
||||
.split(',')
|
||||
.map(|s| s.trim().to_string())
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect();
|
||||
}
|
||||
|
||||
config
|
||||
}
|
||||
}
|
||||
277
autogpt_platform/autogpt-rs/websocket/src/connection_manager.rs
Normal file
277
autogpt_platform/autogpt-rs/websocket/src/connection_manager.rs
Normal file
@@ -0,0 +1,277 @@
|
||||
use futures::StreamExt;
|
||||
use redis::Client as RedisClient;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::models::{ExecutionEvent, RedisEventWrapper, WSMessage};
|
||||
use crate::stats::Stats;
|
||||
|
||||
pub struct ConnectionManager {
|
||||
pub subscribers: RwLock<HashMap<String, HashSet<u64>>>,
|
||||
pub clients: RwLock<HashMap<u64, (String, mpsc::Sender<String>)>>,
|
||||
pub client_channels: RwLock<HashMap<u64, HashSet<String>>>,
|
||||
pub next_id: AtomicU64,
|
||||
pub redis_client: RedisClient,
|
||||
pub bus_name: String,
|
||||
pub stats: Arc<Stats>,
|
||||
}
|
||||
|
||||
impl ConnectionManager {
|
||||
pub fn new(redis_client: RedisClient, bus_name: String, stats: Arc<Stats>) -> Self {
|
||||
Self {
|
||||
subscribers: RwLock::new(HashMap::new()),
|
||||
clients: RwLock::new(HashMap::new()),
|
||||
client_channels: RwLock::new(HashMap::new()),
|
||||
next_id: AtomicU64::new(0),
|
||||
redis_client,
|
||||
bus_name,
|
||||
stats,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_broadcaster(self: Arc<Self>) {
|
||||
info!("🚀 Starting Redis event broadcaster");
|
||||
|
||||
loop {
|
||||
match self.run_broadcaster_inner().await {
|
||||
Ok(_) => {
|
||||
warn!("⚠️ Event broadcaster stopped unexpectedly, restarting in 5 seconds");
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("❌ Event broadcaster error: {}, restarting in 5 seconds", e);
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_broadcaster_inner(
|
||||
self: &Arc<Self>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut pubsub = self.redis_client.get_async_pubsub().await?;
|
||||
pubsub.psubscribe("*").await?;
|
||||
info!(
|
||||
"📡 Listening to all Redis events, filtering for bus: {}",
|
||||
self.bus_name
|
||||
);
|
||||
|
||||
let mut pubsub_stream = pubsub.on_message();
|
||||
|
||||
loop {
|
||||
let msg = pubsub_stream.next().await;
|
||||
match msg {
|
||||
Some(msg) => {
|
||||
let channel: String = msg.get_channel_name().to_string();
|
||||
debug!("📨 Received message on Redis channel: {}", channel);
|
||||
self.stats
|
||||
.redis_messages_received
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
let payload: String = match msg.get_payload() {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
warn!("⚠️ Failed to get payload from Redis message: {}", e);
|
||||
self.stats
|
||||
.errors_total
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Parse the channel format: execution_event/{user_id}/{graph_id}/{graph_exec_id}
|
||||
let parts: Vec<&str> = channel.split('/').collect();
|
||||
|
||||
// Check if this is an execution event channel
|
||||
if parts.len() != 4 || parts[0] != self.bus_name {
|
||||
debug!(
|
||||
"🚫 Ignoring non-execution event channel: {} (parts: {:?}, bus_name: {})",
|
||||
channel, parts, self.bus_name
|
||||
);
|
||||
self.stats
|
||||
.redis_messages_ignored
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
continue;
|
||||
}
|
||||
|
||||
let user_id = parts[1];
|
||||
let graph_id = parts[2];
|
||||
let graph_exec_id = parts[3];
|
||||
|
||||
debug!(
|
||||
"📥 Received event - user: {}, graph: {}, exec: {}",
|
||||
user_id, graph_id, graph_exec_id
|
||||
);
|
||||
|
||||
// Parse the wrapped event
|
||||
let wrapped_event = match RedisEventWrapper::parse(&payload) {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
warn!("⚠️ Failed to parse event JSON: {}, payload: {}", e, payload);
|
||||
self.stats
|
||||
.errors_json_parse
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
self.stats
|
||||
.errors_total
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let event = wrapped_event.payload;
|
||||
debug!("📦 Event received: {:?}", event);
|
||||
|
||||
let (method, event_json) = match &event {
|
||||
ExecutionEvent::GraphExecutionUpdate(graph_event) => {
|
||||
self.stats
|
||||
.graph_execution_events
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
self.stats
|
||||
.events_received_total
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
(
|
||||
"graph_execution_event",
|
||||
match serde_json::to_value(graph_event) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
error!("❌ Failed to serialize graph event: {}", e);
|
||||
continue;
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
ExecutionEvent::NodeExecutionUpdate(node_event) => {
|
||||
self.stats
|
||||
.node_execution_events
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
self.stats
|
||||
.events_received_total
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
(
|
||||
"node_execution_event",
|
||||
match serde_json::to_value(node_event) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
error!("❌ Failed to serialize node event: {}", e);
|
||||
continue;
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
// Create the channel keys in the format expected by WebSocket clients
|
||||
let mut channels_to_notify = Vec::new();
|
||||
|
||||
// For both event types, notify the specific execution channel
|
||||
let exec_channel = format!("{user_id}|graph_exec#{graph_exec_id}");
|
||||
channels_to_notify.push(exec_channel.clone());
|
||||
|
||||
// For graph execution events, also notify the graph executions channel
|
||||
if matches!(&event, ExecutionEvent::GraphExecutionUpdate(_)) {
|
||||
let graph_channel = format!("{user_id}|graph#{graph_id}|executions");
|
||||
channels_to_notify.push(graph_channel);
|
||||
}
|
||||
|
||||
debug!(
|
||||
"📢 Broadcasting {} event to channels: {:?}",
|
||||
method, channels_to_notify
|
||||
);
|
||||
|
||||
let subs = self.subscribers.read().await;
|
||||
|
||||
// Log current subscriber state
|
||||
debug!("📊 Current subscribers count: {}", subs.len());
|
||||
|
||||
for channel_key in channels_to_notify {
|
||||
let ws_msg = WSMessage {
|
||||
method: method.to_string(),
|
||||
channel: Some(channel_key.clone()),
|
||||
data: Some(event_json.clone()),
|
||||
..Default::default()
|
||||
};
|
||||
let json_msg = match serde_json::to_string(&ws_msg) {
|
||||
Ok(j) => {
|
||||
debug!("📤 Sending WebSocket message: {}", j);
|
||||
j
|
||||
}
|
||||
Err(e) => {
|
||||
error!("❌ Failed to serialize WebSocket message: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(client_ids) = subs.get(&channel_key) {
|
||||
let clients = self.clients.read().await;
|
||||
let client_count = client_ids.len();
|
||||
debug!(
|
||||
"📣 Broadcasting to {} clients on channel: {}",
|
||||
client_count, channel_key
|
||||
);
|
||||
|
||||
for &cid in client_ids {
|
||||
if let Some((user_id, tx)) = clients.get(&cid) {
|
||||
match tx.try_send(json_msg.clone()) {
|
||||
Ok(_) => {
|
||||
debug!(
|
||||
"✅ Message sent immediately to client {} (user: {})",
|
||||
cid, user_id
|
||||
);
|
||||
self.stats
|
||||
.messages_sent_total
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Full(_)) => {
|
||||
// Channel is full, try with a small timeout
|
||||
let tx_clone = tx.clone();
|
||||
let msg_clone = json_msg.clone();
|
||||
let stats_clone = self.stats.clone();
|
||||
tokio::spawn(async move {
|
||||
match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(100),
|
||||
tx_clone.send(msg_clone),
|
||||
)
|
||||
.await {
|
||||
Ok(Ok(_)) => {
|
||||
stats_clone
|
||||
.messages_sent_total
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
_ => {
|
||||
stats_clone
|
||||
.messages_failed_total
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
});
|
||||
warn!("⚠️ Channel full for client {} (user: {}), sending async", cid, user_id);
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => {
|
||||
warn!(
|
||||
"⚠️ Channel closed for client {} (user: {})",
|
||||
cid, user_id
|
||||
);
|
||||
self.stats
|
||||
.messages_failed_total
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!("⚠️ Client {} not found in clients map", cid);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
info!("📭 No subscribers for channel: {}", channel_key);
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Err("❌ Redis pubsub stream ended".into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
442
autogpt_platform/autogpt-rs/websocket/src/handlers.rs
Normal file
442
autogpt_platform/autogpt-rs/websocket/src/handlers.rs
Normal file
@@ -0,0 +1,442 @@
|
||||
use axum::extract::ws::{CloseFrame, Message, WebSocket};
|
||||
use axum::{
|
||||
extract::{Query, WebSocketUpgrade},
|
||||
http::HeaderMap,
|
||||
response::IntoResponse,
|
||||
Extension,
|
||||
};
|
||||
use jsonwebtoken::{decode, DecodingKey, Validation};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::connection_manager::ConnectionManager;
|
||||
use crate::models::{Claims, WSMessage};
|
||||
use crate::AppState;
|
||||
|
||||
// Helper function to safely serialize messages
|
||||
fn serialize_message(msg: &WSMessage) -> String {
|
||||
serde_json::to_string(msg).unwrap_or_else(|e| {
|
||||
error!("❌ Failed to serialize WebSocket message: {}", e);
|
||||
json!({"method": "error", "success": false, "error": "Internal serialization error"})
|
||||
.to_string()
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
query: Query<HashMap<String, String>>,
|
||||
_headers: HeaderMap,
|
||||
Extension(state): Extension<AppState>,
|
||||
) -> impl IntoResponse {
|
||||
let token = query.0.get("token").cloned();
|
||||
let mut user_id = state.config.default_user_id.clone();
|
||||
let mut auth_error_code: Option<u16> = None;
|
||||
|
||||
if state.config.enable_auth {
|
||||
match token {
|
||||
Some(token_str) => {
|
||||
debug!("🔐 Authenticating WebSocket connection");
|
||||
let mut validation = Validation::new(state.config.jwt_algorithm);
|
||||
validation.set_audience(&["authenticated"]);
|
||||
|
||||
let key = DecodingKey::from_secret(state.config.jwt_secret.as_bytes());
|
||||
|
||||
match decode::<Claims>(&token_str, &key, &validation) {
|
||||
Ok(token_data) => {
|
||||
user_id = token_data.claims.sub.clone();
|
||||
debug!("✅ WebSocket authenticated for user: {}", user_id);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("⚠️ JWT validation failed: {}", e);
|
||||
auth_error_code = Some(4003);
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
warn!("⚠️ Missing authentication token in WebSocket connection");
|
||||
auth_error_code = Some(4001);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
debug!("🔓 WebSocket connection without auth (auth disabled)");
|
||||
}
|
||||
|
||||
if let Some(code) = auth_error_code {
|
||||
error!("❌ WebSocket authentication failed with code: {}", code);
|
||||
state
|
||||
.mgr
|
||||
.stats
|
||||
.connections_failed_auth
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
state
|
||||
.mgr
|
||||
.stats
|
||||
.connections_total
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
return ws
|
||||
.on_upgrade(move |mut socket: WebSocket| async move {
|
||||
let close_frame = Some(CloseFrame {
|
||||
code,
|
||||
reason: "Authentication failed".into(),
|
||||
});
|
||||
let _ = socket.send(Message::Close(close_frame)).await;
|
||||
let _ = socket.close().await;
|
||||
})
|
||||
.into_response();
|
||||
}
|
||||
|
||||
debug!("✅ WebSocket connection established for user: {}", user_id);
|
||||
ws.on_upgrade(move |socket| {
|
||||
handle_socket(
|
||||
socket,
|
||||
user_id,
|
||||
state.mgr.clone(),
|
||||
state.config.max_message_size_limit,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
async fn update_subscription_stats(mgr: &ConnectionManager, channel: &str, add: bool) {
|
||||
if add {
|
||||
mgr.stats
|
||||
.subscriptions_total
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
mgr.stats
|
||||
.subscriptions_active
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
let mut channel_stats = mgr.stats.channels_active.write().await;
|
||||
let count = channel_stats.entry(channel.to_string()).or_insert(0);
|
||||
*count += 1;
|
||||
} else {
|
||||
mgr.stats
|
||||
.unsubscriptions_total
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
mgr.stats
|
||||
.subscriptions_active
|
||||
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
let mut channel_stats = mgr.stats.channels_active.write().await;
|
||||
if let Some(count) = channel_stats.get_mut(channel) {
|
||||
*count = count.saturating_sub(1);
|
||||
if *count == 0 {
|
||||
channel_stats.remove(channel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_socket(
|
||||
mut socket: WebSocket,
|
||||
user_id: String,
|
||||
mgr: std::sync::Arc<ConnectionManager>,
|
||||
max_size: usize,
|
||||
) {
|
||||
let client_id = mgr
|
||||
.next_id
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
let (tx, mut rx) = mpsc::channel::<String>(10);
|
||||
info!("👋 New WebSocket client {} for user: {}", client_id, user_id);
|
||||
|
||||
// Update connection stats
|
||||
mgr.stats
|
||||
.connections_total
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
mgr.stats
|
||||
.connections_active
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
// Update active users
|
||||
{
|
||||
let mut active_users = mgr.stats.active_users.write().await;
|
||||
let count = active_users.entry(user_id.clone()).or_insert(0);
|
||||
*count += 1;
|
||||
}
|
||||
|
||||
{
|
||||
let mut clients = mgr.clients.write().await;
|
||||
clients.insert(client_id, (user_id.clone(), tx));
|
||||
}
|
||||
|
||||
{
|
||||
let mut client_channels = mgr.client_channels.write().await;
|
||||
client_channels.insert(client_id, std::collections::HashSet::new());
|
||||
}
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = rx.recv() => {
|
||||
if let Some(msg) = msg {
|
||||
if socket.send(Message::Text(msg)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
incoming = socket.recv() => {
|
||||
let msg = match incoming {
|
||||
Some(Ok(msg)) => msg,
|
||||
_ => break,
|
||||
};
|
||||
match msg {
|
||||
Message::Text(text) => {
|
||||
if text.len() > max_size {
|
||||
warn!("⚠️ Message from client {} exceeds size limit: {} > {}", client_id, text.len(), max_size);
|
||||
let err_resp = serialize_message(&WSMessage {
|
||||
method: "error".to_string(),
|
||||
success: Some(false),
|
||||
error: Some("Message exceeds size limit".to_string()),
|
||||
..Default::default()
|
||||
});
|
||||
if socket.send(Message::Text(err_resp)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
mgr.stats.messages_received_total.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
let ws_msg: WSMessage = match serde_json::from_str(&text) {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
warn!("⚠️ Invalid message format from client {}: {}", client_id, e);
|
||||
mgr.stats.errors_json_parse.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
mgr.stats.errors_total.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
let err_resp = serialize_message(&WSMessage {
|
||||
method: "error".to_string(),
|
||||
success: Some(false),
|
||||
error: Some("Invalid message format. Review the schema and retry".to_string()),
|
||||
..Default::default()
|
||||
});
|
||||
if socket.send(Message::Text(err_resp)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
debug!("📥 Received {} message from client {}", ws_msg.method, client_id);
|
||||
|
||||
match ws_msg.method.as_str() {
|
||||
"subscribe_graph_execution" => {
|
||||
let graph_exec_id = match &ws_msg.data {
|
||||
Some(Value::Object(map)) => map.get("graph_exec_id").and_then(|v| v.as_str()),
|
||||
_ => None,
|
||||
};
|
||||
let Some(graph_exec_id) = graph_exec_id else {
|
||||
warn!("⚠️ Missing graph_exec_id in subscribe_graph_execution from client {}", client_id);
|
||||
let err_resp = json!({"method": "error", "success": false, "error": "Missing graph_exec_id"});
|
||||
if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
};
|
||||
let channel = format!("{user_id}|graph_exec#{graph_exec_id}");
|
||||
debug!("📌 Client {} subscribing to channel: {}", client_id, channel);
|
||||
|
||||
{
|
||||
let mut subs = mgr.subscribers.write().await;
|
||||
subs.entry(channel.clone()).or_insert(std::collections::HashSet::new()).insert(client_id);
|
||||
}
|
||||
{
|
||||
let mut chs = mgr.client_channels.write().await;
|
||||
if let Some(set) = chs.get_mut(&client_id) {
|
||||
set.insert(channel.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Update subscription stats
|
||||
update_subscription_stats(&mgr, &channel, true).await;
|
||||
|
||||
let resp = WSMessage {
|
||||
method: "subscribe_graph_execution".to_string(),
|
||||
success: Some(true),
|
||||
channel: Some(channel),
|
||||
..Default::default()
|
||||
};
|
||||
if socket.send(Message::Text(serialize_message(&resp))).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
"subscribe_graph_executions" => {
|
||||
let graph_id = match &ws_msg.data {
|
||||
Some(Value::Object(map)) => map.get("graph_id").and_then(|v| v.as_str()),
|
||||
_ => None,
|
||||
};
|
||||
let Some(graph_id) = graph_id else {
|
||||
let err_resp = json!({"method": "error", "success": false, "error": "Missing graph_id"});
|
||||
if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
};
|
||||
let channel = format!("{user_id}|graph#{graph_id}|executions");
|
||||
|
||||
{
|
||||
let mut subs = mgr.subscribers.write().await;
|
||||
subs.entry(channel.clone()).or_insert(std::collections::HashSet::new()).insert(client_id);
|
||||
}
|
||||
{
|
||||
let mut chs = mgr.client_channels.write().await;
|
||||
if let Some(set) = chs.get_mut(&client_id) {
|
||||
set.insert(channel.clone());
|
||||
}
|
||||
}
|
||||
debug!("📌 Client {} subscribing to channel: {}", client_id, channel);
|
||||
// Update subscription stats
|
||||
update_subscription_stats(&mgr, &channel, true).await;
|
||||
|
||||
let resp = WSMessage {
|
||||
method: "subscribe_graph_executions".to_string(),
|
||||
success: Some(true),
|
||||
channel: Some(channel),
|
||||
..Default::default()
|
||||
};
|
||||
if socket.send(Message::Text(serialize_message(&resp))).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
"unsubscribe" => {
|
||||
let channel = match &ws_msg.data {
|
||||
Some(Value::String(s)) => Some(s.as_str()),
|
||||
Some(Value::Object(map)) => map.get("channel").and_then(|v| v.as_str()),
|
||||
_ => None,
|
||||
};
|
||||
let Some(channel) = channel else {
|
||||
let err_resp = json!({"method": "error", "success": false, "error": "Missing channel"});
|
||||
if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
};
|
||||
let channel = channel.to_string();
|
||||
|
||||
if !channel.starts_with(&format!("{user_id}|")) {
|
||||
let err_resp = json!({"method": "error", "success": false, "error": "Unauthorized channel"});
|
||||
if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
{
|
||||
let mut subs = mgr.subscribers.write().await;
|
||||
if let Some(set) = subs.get_mut(&channel) {
|
||||
set.remove(&client_id);
|
||||
if set.is_empty() {
|
||||
subs.remove(&channel);
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
let mut chs = mgr.client_channels.write().await;
|
||||
if let Some(set) = chs.get_mut(&client_id) {
|
||||
set.remove(&channel);
|
||||
}
|
||||
}
|
||||
|
||||
// Update subscription stats
|
||||
update_subscription_stats(&mgr, &channel, false).await;
|
||||
|
||||
let resp = WSMessage {
|
||||
method: "unsubscribe".to_string(),
|
||||
success: Some(true),
|
||||
channel: Some(channel),
|
||||
..Default::default()
|
||||
};
|
||||
if socket.send(Message::Text(serialize_message(&resp))).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
"heartbeat" => {
|
||||
if ws_msg.data == Some(Value::String("ping".to_string())) {
|
||||
let resp = WSMessage {
|
||||
method: "heartbeat".to_string(),
|
||||
data: Some(Value::String("pong".to_string())),
|
||||
success: Some(true),
|
||||
..Default::default()
|
||||
};
|
||||
if socket.send(Message::Text(serialize_message(&resp))).await.is_err() {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
let err_resp = json!({"method": "error", "success": false, "error": "Invalid heartbeat"});
|
||||
if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
warn!("❓ Unknown method '{}' from client {}", ws_msg.method, client_id);
|
||||
let err_resp = json!({"method": "error", "success": false, "error": "Unknown method"});
|
||||
if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Message::Close(_) => break,
|
||||
Message::Ping(_) => {
|
||||
if socket.send(Message::Pong(vec![])).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Message::Pong(_) => {}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
else => break,
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
debug!("👋 WebSocket client {} disconnected, cleaning up", client_id);
|
||||
|
||||
// Update connection stats
|
||||
mgr.stats
|
||||
.connections_active
|
||||
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
// Update active users
|
||||
{
|
||||
let mut active_users = mgr.stats.active_users.write().await;
|
||||
if let Some(count) = active_users.get_mut(&user_id) {
|
||||
*count = count.saturating_sub(1);
|
||||
if *count == 0 {
|
||||
active_users.remove(&user_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let channels = {
|
||||
let mut client_channels = mgr.client_channels.write().await;
|
||||
client_channels.remove(&client_id).unwrap_or_default()
|
||||
};
|
||||
|
||||
{
|
||||
let mut subs = mgr.subscribers.write().await;
|
||||
for channel in &channels {
|
||||
if let Some(set) = subs.get_mut(channel) {
|
||||
set.remove(&client_id);
|
||||
if set.is_empty() {
|
||||
subs.remove(channel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update subscription stats for all channels the client was subscribed to
|
||||
for channel in &channels {
|
||||
update_subscription_stats(&mgr, channel, false).await;
|
||||
}
|
||||
|
||||
{
|
||||
let mut clients = mgr.clients.write().await;
|
||||
clients.remove(&client_id);
|
||||
}
|
||||
|
||||
debug!("✨ Cleanup completed for client {}", client_id);
|
||||
}
|
||||
26
autogpt_platform/autogpt-rs/websocket/src/lib.rs
Normal file
26
autogpt_platform/autogpt-rs/websocket/src/lib.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
#![deny(warnings)]
|
||||
#![deny(clippy::unwrap_used)]
|
||||
#![deny(clippy::panic)]
|
||||
#![deny(clippy::unimplemented)]
|
||||
#![deny(clippy::todo)]
|
||||
|
||||
|
||||
pub mod config;
|
||||
pub mod connection_manager;
|
||||
pub mod handlers;
|
||||
pub mod models;
|
||||
pub mod stats;
|
||||
|
||||
pub use config::Config;
|
||||
pub use connection_manager::ConnectionManager;
|
||||
pub use handlers::ws_handler;
|
||||
pub use stats::Stats;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub mgr: Arc<ConnectionManager>,
|
||||
pub config: Arc<Config>,
|
||||
pub stats: Arc<Stats>,
|
||||
}
|
||||
172
autogpt_platform/autogpt-rs/websocket/src/main.rs
Normal file
172
autogpt_platform/autogpt-rs/websocket/src/main.rs
Normal file
@@ -0,0 +1,172 @@
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{header, StatusCode},
|
||||
response::Response,
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use clap::Parser;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use tracing::{debug, error, info};
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::connection_manager::ConnectionManager;
|
||||
use crate::handlers::ws_handler;
|
||||
|
||||
async fn stats_handler(
|
||||
axum::Extension(state): axum::Extension<AppState>,
|
||||
) -> Result<axum::response::Json<stats::StatsSnapshot>, StatusCode> {
|
||||
let snapshot = state.stats.snapshot().await;
|
||||
Ok(axum::response::Json(snapshot))
|
||||
}
|
||||
|
||||
async fn prometheus_handler(
|
||||
axum::Extension(state): axum::Extension<AppState>,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let snapshot = state.stats.snapshot().await;
|
||||
let prometheus_text = state.stats.to_prometheus_format(&snapshot);
|
||||
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, "text/plain; version=0.0.4")
|
||||
.body(Body::from(prometheus_text))
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
|
||||
mod config;
|
||||
mod connection_manager;
|
||||
mod handlers;
|
||||
mod models;
|
||||
mod stats;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about)]
|
||||
struct Cli {
|
||||
/// Path to a TOML configuration file
|
||||
#[arg(short = 'c', long = "config", value_name = "FILE")]
|
||||
config: Option<std::path::PathBuf>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
mgr: Arc<ConnectionManager>,
|
||||
config: Arc<Config>,
|
||||
stats: Arc<stats::Stats>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
// Initialize tracing
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| "websocket=info,tower_http=debug".into()),
|
||||
)
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
|
||||
info!("🚀 Starting WebSocket API server");
|
||||
|
||||
let cli = Cli::parse();
|
||||
let config = Arc::new(Config::load(cli.config.as_deref()));
|
||||
info!(
|
||||
"⚙️ Configuration loaded - host: {}, port: {}, auth: {}",
|
||||
config.host, config.port, config.enable_auth
|
||||
);
|
||||
|
||||
let redis_client = match redis::Client::open(config.redis_url.clone()) {
|
||||
Ok(client) => {
|
||||
debug!("✅ Redis client created successfully");
|
||||
client
|
||||
}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"❌ Failed to create Redis client: {}. Please check REDIS_URL environment variable",
|
||||
e
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
|
||||
let stats = Arc::new(stats::Stats::default());
|
||||
let mgr = Arc::new(ConnectionManager::new(
|
||||
redis_client,
|
||||
config.execution_event_bus_name.clone(),
|
||||
stats.clone(),
|
||||
));
|
||||
|
||||
let mgr_clone = mgr.clone();
|
||||
tokio::spawn(async move {
|
||||
debug!("📡 Starting event broadcaster task");
|
||||
mgr_clone.run_broadcaster().await;
|
||||
});
|
||||
|
||||
let state = AppState {
|
||||
mgr,
|
||||
config: config.clone(),
|
||||
stats,
|
||||
};
|
||||
|
||||
let app = Router::new()
|
||||
.route("/ws", get(ws_handler))
|
||||
.route("/stats", get(stats_handler))
|
||||
.route("/metrics", get(prometheus_handler))
|
||||
.layer(axum::Extension(state));
|
||||
|
||||
let cors = if config.backend_cors_allow_origins.is_empty() {
|
||||
// If no specific origins configured, allow any origin but without credentials
|
||||
CorsLayer::new()
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any)
|
||||
.allow_origin(Any)
|
||||
} else {
|
||||
// If specific origins configured, allow credentials
|
||||
CorsLayer::new()
|
||||
.allow_methods([
|
||||
axum::http::Method::GET,
|
||||
axum::http::Method::POST,
|
||||
axum::http::Method::PUT,
|
||||
axum::http::Method::DELETE,
|
||||
axum::http::Method::OPTIONS,
|
||||
])
|
||||
.allow_headers(vec![
|
||||
axum::http::header::CONTENT_TYPE,
|
||||
axum::http::header::AUTHORIZATION,
|
||||
])
|
||||
.allow_credentials(true)
|
||||
.allow_origin(
|
||||
config
|
||||
.backend_cors_allow_origins
|
||||
.iter()
|
||||
.filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
};
|
||||
|
||||
let app = app.layer(cors);
|
||||
|
||||
let addr = format!("{}:{}", config.host, config.port);
|
||||
let listener = match TcpListener::bind(&addr).await {
|
||||
Ok(listener) => {
|
||||
info!("🎧 WebSocket server listening on: {}", addr);
|
||||
listener
|
||||
}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"❌ Failed to bind to {}: {}. Please check if the port is already in use",
|
||||
addr, e
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
|
||||
info!("✨ WebSocket API server ready to accept connections");
|
||||
|
||||
if let Err(e) = axum::serve(listener, app.into_make_service()).await {
|
||||
error!("💥 Server error: {}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
103
autogpt_platform/autogpt-rs/websocket/src/models.rs
Normal file
103
autogpt_platform/autogpt-rs/websocket/src/models.rs
Normal file
@@ -0,0 +1,103 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct WSMessage {
|
||||
pub method: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<Value>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub success: Option<bool>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub channel: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct Claims {
|
||||
pub sub: String,
|
||||
}
|
||||
|
||||
// Event models moved from events.rs
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "event_type")]
|
||||
pub enum ExecutionEvent {
|
||||
#[serde(rename = "graph_execution_update")]
|
||||
GraphExecutionUpdate(GraphExecutionEvent),
|
||||
#[serde(rename = "node_execution_update")]
|
||||
NodeExecutionUpdate(NodeExecutionEvent),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GraphExecutionEvent {
|
||||
pub id: String,
|
||||
pub graph_id: String,
|
||||
pub graph_version: u32,
|
||||
pub user_id: String,
|
||||
pub status: ExecutionStatus,
|
||||
pub started_at: Option<String>,
|
||||
pub ended_at: Option<String>,
|
||||
pub preset_id: Option<String>,
|
||||
pub stats: Option<ExecutionStats>,
|
||||
|
||||
// Keep these as JSON since they vary by graph
|
||||
pub inputs: Value,
|
||||
pub outputs: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NodeExecutionEvent {
|
||||
pub node_exec_id: String,
|
||||
pub node_id: String,
|
||||
pub graph_exec_id: String,
|
||||
pub graph_id: String,
|
||||
pub graph_version: u32,
|
||||
pub user_id: String,
|
||||
pub block_id: String,
|
||||
pub status: ExecutionStatus,
|
||||
pub add_time: String,
|
||||
pub queue_time: Option<String>,
|
||||
pub start_time: Option<String>,
|
||||
pub end_time: Option<String>,
|
||||
|
||||
// Keep these as JSON since they vary by node type
|
||||
pub input_data: Value,
|
||||
pub output_data: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExecutionStats {
|
||||
pub cost: f64,
|
||||
pub duration: f64,
|
||||
pub duration_cpu_only: f64,
|
||||
pub error: Option<String>,
|
||||
pub node_error_count: u32,
|
||||
pub node_exec_count: u32,
|
||||
pub node_exec_time: f64,
|
||||
pub node_exec_time_cpu_only: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
|
||||
pub enum ExecutionStatus {
|
||||
Queued,
|
||||
Running,
|
||||
Completed,
|
||||
Failed,
|
||||
Incomplete,
|
||||
Terminated,
|
||||
}
|
||||
|
||||
// Wrapper for the Redis event that includes the payload
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RedisEventWrapper {
|
||||
pub payload: ExecutionEvent,
|
||||
}
|
||||
|
||||
impl RedisEventWrapper {
|
||||
pub fn parse(json_str: &str) -> Result<Self, serde_json::Error> {
|
||||
serde_json::from_str(json_str)
|
||||
}
|
||||
}
|
||||
238
autogpt_platform/autogpt-rs/websocket/src/stats.rs
Normal file
238
autogpt_platform/autogpt-rs/websocket/src/stats.rs
Normal file
@@ -0,0 +1,238 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Stats {
|
||||
// Connection metrics
|
||||
pub connections_total: AtomicU64,
|
||||
pub connections_active: AtomicU64,
|
||||
pub connections_failed_auth: AtomicU64,
|
||||
|
||||
// Message metrics
|
||||
pub messages_received_total: AtomicU64,
|
||||
pub messages_sent_total: AtomicU64,
|
||||
pub messages_failed_total: AtomicU64,
|
||||
|
||||
// Subscription metrics
|
||||
pub subscriptions_total: AtomicU64,
|
||||
pub subscriptions_active: AtomicU64,
|
||||
pub unsubscriptions_total: AtomicU64,
|
||||
|
||||
// Event metrics by type
|
||||
pub events_received_total: AtomicU64,
|
||||
pub graph_execution_events: AtomicU64,
|
||||
pub node_execution_events: AtomicU64,
|
||||
|
||||
// Redis metrics
|
||||
pub redis_messages_received: AtomicU64,
|
||||
pub redis_messages_ignored: AtomicU64,
|
||||
|
||||
// Channel metrics
|
||||
pub channels_active: RwLock<HashMap<String, usize>>, // channel -> subscriber count
|
||||
|
||||
// User metrics
|
||||
pub active_users: RwLock<HashMap<String, usize>>, // user_id -> connection count
|
||||
|
||||
// Error metrics
|
||||
pub errors_total: AtomicU64,
|
||||
pub errors_json_parse: AtomicU64,
|
||||
pub errors_message_size: AtomicU64,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct StatsSnapshot {
|
||||
// Connection metrics
|
||||
pub connections_total: u64,
|
||||
pub connections_active: u64,
|
||||
pub connections_failed_auth: u64,
|
||||
|
||||
// Message metrics
|
||||
pub messages_received_total: u64,
|
||||
pub messages_sent_total: u64,
|
||||
pub messages_failed_total: u64,
|
||||
|
||||
// Subscription metrics
|
||||
pub subscriptions_total: u64,
|
||||
pub subscriptions_active: u64,
|
||||
pub unsubscriptions_total: u64,
|
||||
|
||||
// Event metrics
|
||||
pub events_received_total: u64,
|
||||
pub graph_execution_events: u64,
|
||||
pub node_execution_events: u64,
|
||||
|
||||
// Redis metrics
|
||||
pub redis_messages_received: u64,
|
||||
pub redis_messages_ignored: u64,
|
||||
|
||||
// Channel metrics
|
||||
pub channels_active_count: usize,
|
||||
pub total_subscribers: usize,
|
||||
|
||||
// User metrics
|
||||
pub active_users_count: usize,
|
||||
|
||||
// Error metrics
|
||||
pub errors_total: u64,
|
||||
pub errors_json_parse: u64,
|
||||
pub errors_message_size: u64,
|
||||
}
|
||||
|
||||
impl Stats {
|
||||
pub async fn snapshot(&self) -> StatsSnapshot {
|
||||
// Take read locks for HashMap data - it's ok if this is slightly stale
|
||||
let channels = self.channels_active.read().await;
|
||||
let total_subscribers: usize = channels.values().sum();
|
||||
let channels_active_count = channels.len();
|
||||
drop(channels); // Release lock early
|
||||
|
||||
let users = self.active_users.read().await;
|
||||
let active_users_count = users.len();
|
||||
drop(users); // Release lock early
|
||||
|
||||
StatsSnapshot {
|
||||
connections_total: self.connections_total.load(Ordering::Relaxed),
|
||||
connections_active: self.connections_active.load(Ordering::Relaxed),
|
||||
connections_failed_auth: self.connections_failed_auth.load(Ordering::Relaxed),
|
||||
|
||||
messages_received_total: self.messages_received_total.load(Ordering::Relaxed),
|
||||
messages_sent_total: self.messages_sent_total.load(Ordering::Relaxed),
|
||||
messages_failed_total: self.messages_failed_total.load(Ordering::Relaxed),
|
||||
|
||||
subscriptions_total: self.subscriptions_total.load(Ordering::Relaxed),
|
||||
subscriptions_active: self.subscriptions_active.load(Ordering::Relaxed),
|
||||
unsubscriptions_total: self.unsubscriptions_total.load(Ordering::Relaxed),
|
||||
|
||||
events_received_total: self.events_received_total.load(Ordering::Relaxed),
|
||||
graph_execution_events: self.graph_execution_events.load(Ordering::Relaxed),
|
||||
node_execution_events: self.node_execution_events.load(Ordering::Relaxed),
|
||||
|
||||
redis_messages_received: self.redis_messages_received.load(Ordering::Relaxed),
|
||||
redis_messages_ignored: self.redis_messages_ignored.load(Ordering::Relaxed),
|
||||
|
||||
channels_active_count,
|
||||
total_subscribers,
|
||||
active_users_count,
|
||||
|
||||
errors_total: self.errors_total.load(Ordering::Relaxed),
|
||||
errors_json_parse: self.errors_json_parse.load(Ordering::Relaxed),
|
||||
errors_message_size: self.errors_message_size.load(Ordering::Relaxed),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_prometheus_format(&self, snapshot: &StatsSnapshot) -> String {
|
||||
let mut output = String::new();
|
||||
|
||||
// Connection metrics
|
||||
output.push_str("# HELP ws_connections_total Total number of WebSocket connections\n");
|
||||
output.push_str("# TYPE ws_connections_total counter\n");
|
||||
output.push_str(&format!(
|
||||
"ws_connections_total {}\n\n",
|
||||
snapshot.connections_total
|
||||
));
|
||||
|
||||
output.push_str(
|
||||
"# HELP ws_connections_active Current number of active WebSocket connections\n",
|
||||
);
|
||||
output.push_str("# TYPE ws_connections_active gauge\n");
|
||||
output.push_str(&format!(
|
||||
"ws_connections_active {}\n\n",
|
||||
snapshot.connections_active
|
||||
));
|
||||
|
||||
output
|
||||
.push_str("# HELP ws_connections_failed_auth Total number of failed authentications\n");
|
||||
output.push_str("# TYPE ws_connections_failed_auth counter\n");
|
||||
output.push_str(&format!(
|
||||
"ws_connections_failed_auth {}\n\n",
|
||||
snapshot.connections_failed_auth
|
||||
));
|
||||
|
||||
// Message metrics
|
||||
output.push_str(
|
||||
"# HELP ws_messages_received_total Total number of messages received from clients\n",
|
||||
);
|
||||
output.push_str("# TYPE ws_messages_received_total counter\n");
|
||||
output.push_str(&format!(
|
||||
"ws_messages_received_total {}\n\n",
|
||||
snapshot.messages_received_total
|
||||
));
|
||||
|
||||
output.push_str("# HELP ws_messages_sent_total Total number of messages sent to clients\n");
|
||||
output.push_str("# TYPE ws_messages_sent_total counter\n");
|
||||
output.push_str(&format!(
|
||||
"ws_messages_sent_total {}\n\n",
|
||||
snapshot.messages_sent_total
|
||||
));
|
||||
|
||||
// Subscription metrics
|
||||
output.push_str("# HELP ws_subscriptions_active Current number of active subscriptions\n");
|
||||
output.push_str("# TYPE ws_subscriptions_active gauge\n");
|
||||
output.push_str(&format!(
|
||||
"ws_subscriptions_active {}\n\n",
|
||||
snapshot.subscriptions_active
|
||||
));
|
||||
|
||||
// Event metrics
|
||||
output.push_str(
|
||||
"# HELP ws_events_received_total Total number of events received from Redis\n",
|
||||
);
|
||||
output.push_str("# TYPE ws_events_received_total counter\n");
|
||||
output.push_str(&format!(
|
||||
"ws_events_received_total {}\n\n",
|
||||
snapshot.events_received_total
|
||||
));
|
||||
|
||||
output.push_str(
|
||||
"# HELP ws_graph_execution_events_total Total number of graph execution events\n",
|
||||
);
|
||||
output.push_str("# TYPE ws_graph_execution_events_total counter\n");
|
||||
output.push_str(&format!(
|
||||
"ws_graph_execution_events_total {}\n\n",
|
||||
snapshot.graph_execution_events
|
||||
));
|
||||
|
||||
output.push_str(
|
||||
"# HELP ws_node_execution_events_total Total number of node execution events\n",
|
||||
);
|
||||
output.push_str("# TYPE ws_node_execution_events_total counter\n");
|
||||
output.push_str(&format!(
|
||||
"ws_node_execution_events_total {}\n\n",
|
||||
snapshot.node_execution_events
|
||||
));
|
||||
|
||||
// Channel metrics
|
||||
output.push_str("# HELP ws_channels_active Number of active channels\n");
|
||||
output.push_str("# TYPE ws_channels_active gauge\n");
|
||||
output.push_str(&format!(
|
||||
"ws_channels_active {}\n\n",
|
||||
snapshot.channels_active_count
|
||||
));
|
||||
|
||||
output.push_str(
|
||||
"# HELP ws_total_subscribers Total number of subscribers across all channels\n",
|
||||
);
|
||||
output.push_str("# TYPE ws_total_subscribers gauge\n");
|
||||
output.push_str(&format!(
|
||||
"ws_total_subscribers {}\n\n",
|
||||
snapshot.total_subscribers
|
||||
));
|
||||
|
||||
// User metrics
|
||||
output.push_str("# HELP ws_active_users Number of unique users with active connections\n");
|
||||
output.push_str("# TYPE ws_active_users gauge\n");
|
||||
output.push_str(&format!(
|
||||
"ws_active_users {}\n\n",
|
||||
snapshot.active_users_count
|
||||
));
|
||||
|
||||
// Error metrics
|
||||
output.push_str("# HELP ws_errors_total Total number of errors\n");
|
||||
output.push_str("# TYPE ws_errors_total counter\n");
|
||||
output.push_str(&format!("ws_errors_total {}\n", snapshot.errors_total));
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user