Compare commits

..

11 Commits

Author SHA1 Message Date
Reinier van der Leer
bef6e08f3d fix regex hanging 2025-09-23 23:11:33 +02:00
Reinier van der Leer
b8576a9b78 minimize changes to poetry files 2025-09-23 16:47:48 +02:00
Reinier van der Leer
2796fde620 Merge branch 'dev' into pwuts/open-2714-implement-lenient-json-parsing-in-ai-structured-response 2025-09-23 16:42:34 +02:00
Reinier van der Leer
e7b15c39aa Make JSON output optional in LLM calls 2025-09-23 16:40:38 +02:00
Zamil Majdy
5f5ee9dab0 feat(backend): implement comprehensive load testing performance fixes + database health improvements (#10965)
## Summary

This PR implements comprehensive performance fixes for the AutoGPT
Platform, addressing critical load testing bottlenecks and database
connectivity issues:

### 🚀 Load Testing Infrastructure & Performance Fixes

**Problem:** Platform failed under 100+ RPS load due to database
bottlenecks and inefficient API patterns.

**Root Cause Analysis:**
- Database connection pool exhaustion (50 connections/pod limit)
- Unnecessary `/api/auth/user` calls causing repeated database lookups
- Large API payloads (4.6MB for `/api/graphs`, 3.9MB for
`/api/executions`)
- Missing user profile caching led to O(n) database queries per request

**Solutions Implemented:**

#### 1. User Profile Caching System
- **Implementation:** Use existing `async_ttl_cache` from autogpt_libs
for user lookup optimization
- **Applied to:** `get_or_create_user`, `get_user_by_id`,
`get_user_by_email` in `/backend/backend/data/user.py`
- **Performance impact:** Eliminates repeated database lookups, reduces
connection pool pressure
- **Cache invalidation:** Automatic cleanup on user updates maintains
data consistency

#### 2. Load Test Infrastructure Overhaul
- **k6 Load Testing Suite:** Production-ready tests with Grafana Cloud
integration
- **Rate Limit Optimization:** Configure for 5 VUs with high
`REQUESTS_PER_VU` to avoid Supabase limits
- **Realistic User Workflows:** Replace API hammering with actual user
journey patterns
- **Comprehensive Coverage:** Basic connectivity, core APIs, graph
execution, platform integration

#### 3. Interactive Load Testing Tools

**Interactive CLI (`interactive-test.js`):**
- Guided test selection with descriptions and recommendations
- Environment targeting (Local, Dev, Production) 
- Parameter validation and k6 cloud integration
- User-friendly interface for non-technical users

**Enhanced Single Endpoint Testing (`single-endpoint-test.js`):**
- Support for up to 500 concurrent requests per VU using `http.batch()`
- Individual endpoint debugging (credits, graphs, blocks, executions)
- Burst load testing capabilities for RPS validation
- Performance isolation for specific API bottlenecks

#### 4. Sub-Agent Approval Automation
- **Auto-approve sub-agents** when main agent is approved for seamless
store workflow
- **Transaction safety** with atomic operations via database
transactions
- **Parallel processing** using asyncio.gather for performance
- **Hidden from store** with isAvailable=false for sub-agents

**Performance Results:**
-  **API Payload Optimization:** /api/graphs reduced from 4.6MB → 149KB
(97% smaller, 20x faster)
-  **Load Testing Success:** Sustained 100+ RPS with 100% success rate
on k6 cloud
-  **Database Efficiency:** User profile caching eliminates repeated
lookups
-  **Rate Limit Resolution:** No more 429 errors with optimized VU
configuration

### 🔧 Database Health Check & Retry Improvements

**Problem:** Execution ID `230a8036-9ba7-47c3-8f01-40bf21a9ff42` failed
due to database connectivity issues.

**Database Manager Health Check Improvements:**
- **Replace ineffective health check** that only checked
`db.is_connected()`
- **Add actual database query test** using
`db.query_raw_with_schema("SELECT 1 as health_check")`
- **Detect Prisma query engine failures** that cause HTTP 500 errors
during execution
- **Detailed error messaging** for better observability and debugging

**Service Communication Retry Enhancements:**
- **Increase retry attempts** from 5 to 8 for better resilience  
- **Add configurable max wait time** (`pyro_client_max_wait`, default
15s vs 5s hardcoded)
- **Longer retry intervals** help handle intermittent Supabase
connectivity issues

## Impact

### Load Testing & Performance
-  **Eliminates database bottlenecks** preventing 100+ RPS load
capacity
-  **User profile caching** reduces database connection pool pressure  
-  **Realistic load testing** validates actual platform capacity vs
infrastructure limits
-  **Interactive tools** enable easy performance debugging and
validation
-  **Sub-agent automation** streamlines store approval workflows

### Database Reliability  
-  **Prevents execution engine failures** from database connectivity
issues
-  **Kubernetes will restart unhealthy pods** when real database
problems occur
-  **Better handling of intermittent Supabase service degradation**
-  **More resilient service-to-service communication** across the
platform

## Test Results

**Load Testing Validation (k6 Cloud Project ID: 4254406):**
- Basic Connectivity: 560+ RPS sustained, 100% success rate
- Core API: 500 concurrent requests, 100% success rate, full 7-minute
duration
- Graph Execution: 100 concurrent operations, successful under load
- Comprehensive Platform: End-to-end user workflows, 100% completion
rate

**Database Health Check:**
- [x] Health check properly fails when database queries fail
- [x] Health check passes when database is working correctly  
- [x] Retry configuration applied to database manager calls
- [x] Code formatting and linting pass

## Files Modified

**Load Testing Infrastructure:**
- `load-tests/interactive-test.js` - Interactive CLI for guided test
execution
- `load-tests/single-endpoint-test.js` - Enhanced single endpoint
testing with high concurrency
- `load-tests/README.md` - Comprehensive documentation with usage
examples
- Multiple existing test files enhanced for k6 cloud compatibility

**Performance Optimizations:**
- `backend/backend/data/user.py` - User profile caching implementation
- `backend/backend/server/routers/v1.py` - API endpoint improvements  
- `backend/backend/data/store.py` - Sub-agent approval automation

**Database & Health Checks:**
- `backend/backend/data/db.py` - Enhanced database health check with
actual query testing
- `backend/backend/util/service.py` - Improved retry configuration for
service communication

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-09-23 10:49:09 +00:00
Nicholas Tindle
35325a5127 fix(backend): make preset migration not crash the system (#10966)
<!-- Clearly explain the need for these changes: -->
For those who develop blocks, they may or may not exist in the code at
the same time as the database.
> Create block in one branch, test, then move to another branch the
block is not in

This migration will prevent startup in that case.

### Changes 🏗️
Adds a try except around the migration
<!-- Concisely describe all of the changes made in this pull request:
-->

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Test that startup actually works

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2025-09-22 20:33:57 +00:00
Zamil Majdy
be72cc6d19 fix(load-tests): resolve k6 VU crashes and authentication distribution issues (#10962)
## Summary
Fix critical k6 load testing issues that were causing VU crashes and
preventing proper high-throughput testing. This enables reliable 100+
RPS load testing for the AutoGPT Platform.

## Root Cause Analysis
- **VU Crashes**: All VUs were trying to authenticate with the same user
(loadtest1@example.com), causing auth failures and VU crashes with
`throw Error()`
- **k6 Cloud Aborts**: `maxDuration` field not supported in k6 cloud,
causing immediate test aborts
- **Low RPS**: Missing `REQUESTS_PER_VU` parameter caused graph tests to
achieve only 4 RPS instead of 100+ RPS

## Changes Made

### Core Authentication Fixes
- **Round-robin user assignment**: Fixed user distribution logic in
`utils/auth.js`
  ```javascript
  // Before: All VUs used loadtest1@example.com
const assignedUserIndex = (vuId - 1) % users.length; // Now: Round-robin
across 3 users
  ```
- **Graceful error handling**: Changed from `throw Error()` to `return
null` to prevent VU crashes
- **Null authentication checks**: Added proper handling in all test
scripts to gracefully skip iterations instead of crashing

### k6 Cloud Compatibility  
- **Remove unsupported maxDuration**: Eliminated from all test scripts
(basic-connectivity, core-api, graph-execution)
- **Enhanced cloud configuration**: Proper project ID and timeout
settings for reliable cloud execution

### Performance & Concurrency
- **REQUESTS_PER_VU support**: All tests now properly support concurrent
operations parameter
- **Concurrent graph operations**: Graph test now supports `VUS=5
REQUESTS_PER_VU=20` for 100+ concurrent operations
- **Proper load distribution**: Authentication load distributed across 3
test users instead of overwhelming single user

## Test Results

### Before Fix
 VU crashes: "missing 1 required keyword-only argument: user_context"  
 k6 cloud aborts: "maxDuration not supported"  
 Low RPS: Graph test achieving only 4 RPS  
 Auth failures: All VUs fighting over same user  

### After Fix 
- **100% success rate** across all test types
- **400 graph creations** at 4.60/s sustained throughput  
- **100 graph executions** at 1.15/s
- **0% HTTP failures** (0 out of 505 requests)
- **P95 latency**: 8.49s (well under 45s threshold)
- **Stable VUs**: No crashes, graceful auth failure handling

## k6 Cloud Results
- Basic Connectivity:
https://significantgravitas.grafana.net/a/k6-app/runs/5591228
- All tests running successfully with proper authentication distribution
- Achieved target 100+ RPS with concurrent operations

## Files Modified
- `utils/auth.js` - Fixed user assignment and error handling
- `basic-connectivity-test.js` - Added null auth checks, removed
maxDuration
- `core-api-load-test.js` - Added null auth checks, removed maxDuration,
added REQUESTS_PER_VU support
- `graph-execution-load-test.js` - Added null auth checks, removed
maxDuration, enhanced concurrency

## Usage Examples
```bash
# Basic connectivity at 100 RPS
K6_ENVIRONMENT=DEV VUS=10 DURATION=1m k6 run basic-connectivity-test.js

# Graph operations at 100+ RPS  
K6_ENVIRONMENT=DEV VUS=5 DURATION=1m REQUESTS_PER_VU=20 k6 run graph-execution-load-test.js

# k6 Cloud execution
K6_CLOUD_TOKEN=xxx K6_CLOUD_PROJECT_ID=xxx VUS=10 DURATION=30s k6 run basic-connectivity-test.js --out cloud
```

## Impact
- **Prevents VU crashes**: Tests remain stable under high concurrency
- **Enables k6 cloud**: All tests compatible with k6 cloud
infrastructure
- **Achieves 100+ RPS**: Proper concurrent operations support
- **Better observability**: Clear logging of authentication assignment
and failures
- **Production ready**: Reliable load testing for performance validation

🤖 Generated with [Claude Code](https://claude.ai/code)

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-09-22 23:01:41 +07:00
Zamil Majdy
e881c5d2f4 fix(load-tests): resolve k6 VU crashes and authentication distribution issues
## Problem
k6 load tests were experiencing VU crashes causing false failure rates:
- VUs crashed when authentication failed, skewing success metrics
- All VUs tried same user (loadtest1) first, causing auth conflicts
- 48% success rate was due to test methodology, not server issues

## Root Cause Analysis
1. **Poor user distribution**: All VUs attempted auth with user[0] first
2. **Unhandled auth failures**: throw Error() crashed entire VUs
3. **Concurrent auth conflicts**: Multiple VUs hitting same credentials

## Changes Made

### Fix Authentication Distribution (utils/auth.js)
- **Round-robin user assignment**: VU 1,4,7→user1, VU 2,5,8→user2, VU 3,6,9→user3
- **Fallback logic**: Try assigned user first, then others if needed
- **Graceful failure handling**: Return null instead of throwing errors

### Fix VU Crash Handling (basic-connectivity-test.js)
- **Null auth check**: Skip iteration gracefully when auth fails
- **Prevent VU crashes**: Continue test execution without crashing VU
- **Proper error tracking**: Log auth failures without breaking test flow

## Results
- **Before**: 48% success rate, 4/10 VUs crashed, P95 >8s
- **After**: 100% success rate, 10/10 VUs stable, P95 333ms
- **System capability**: Can handle 100+ RPS (previous issues were test bugs)

## Test Evidence
- Fixed authentication conflicts with 3 users supporting 10 VUs
- All VUs remain stable throughout test duration
- Zero server-side errors during previously 'failed' tests

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-22 15:46:19 +07:00
Zamil Majdy
bb20821634 feat(backend): Add k6 load testing infrastructure + fix critical performance issues (#10941)
# AutoGPT Platform Load Testing Infrastructure

A comprehensive k6-based load testing suite for AutoGPT Platform API
with Grafana Cloud integration for real-time monitoring and performance
analysis.

## 🚀 Quick Start

### Prerequisites
- k6 installed ([Install
Guide](https://k6.io/docs/getting-started/installation/))
- Backend server running (port 8006)
- Valid test user credentials

### Running Tests

#### 1. Setup Test Users (First Time Only)
```bash
cd autogpt_platform/backend/load-tests
k6 run setup-test-users.js
```

#### 2. Basic Load Tests
```bash
# Test API connectivity and authentication
k6 run basic-connectivity-test.js

# Test core API endpoints (credits, profiles)
k6 run core-api-load-test.js

# Test graph operations (create, execute)
k6 run graph-execution-load-test.js

# Full platform integration test
k6 run scenarios/comprehensive-platform-load-test.js
```

#### 3. Run with Grafana Cloud (Optional)
```bash
# Set environment variables
export K6_CLOUD_TOKEN="your-grafana-cloud-token"
export K6_CLOUD_PROJECT_ID="your-project-id"

# Run with cloud monitoring
k6 run basic-connectivity-test.js --out cloud
```

## 📊 Test Scenarios

| Test | Purpose | Endpoints Tested | Load Pattern |
|------|---------|-----------------|-------------|
| **Basic Connectivity** | Validate infrastructure | Auth, health checks
| 1-10 VUs, 10s-5m |
| **Core API** | Test CRUD operations | /api/credits, /api/auth/user |
1-5 VUs, 30s-2m |
| **Graph Execution** | Test graph workflows | /api/graphs,
/api/graphs/*/execute | 1-3 VUs, 1-3m |
| **Comprehensive** | End-to-end user journeys | All major endpoints |
1-2 VUs, 2-5m |

## 🔧 Configuration

### Environment Variables
```bash
# Target Environment
export K6_ENVIRONMENT="dev"    # dev, local, staging

# Load Test Parameters  
export VUS="5"                 # Virtual users (concurrent)
export DURATION="2m"           # Test duration
export REQUESTS_PER_VU="10"    # Requests per user

# Grafana Cloud (Optional)
export K6_CLOUD_TOKEN="your-token"
export K6_CLOUD_PROJECT_ID="your-project-id"
```

### Test Environments
- **LOCAL**: localhost:8006 (development)
- **DEV**: dev-server.agpt.co (staging)

## 📈 Performance Thresholds

Current SLA targets:
- **Response Time P95**: < 2 seconds
- **Error Rate**: < 5%
- **Authentication Success**: > 95%
- **Graph Creation**: < 5 seconds
- **Graph Execution**: < 30 seconds

## 🔍 Current Performance Issues Identified

⚠️ **Load testing reveals significant performance bottlenecks that need
optimization:**

### 📊 **Load Test Results**
| Endpoint | RPS | P95 Latency | Success Rate | Status |
|----------|-----|-------------|--------------|---------|
| Basic Connectivity | 40.6 | 926ms | 99.15% |  |
| Core API | 4.6 | 24.2s | 99.83% | ⚠️ |
| Graph Execution | 1.1 | 47.8s | 70.28% |  |
| Comprehensive Platform | 0.3 | 44.2s | 96.25% |  |

### 🚨 **Critical Issues Requiring Performance Work**
1. **Graph Operations**: 70% failure rate under load, P95 latency 47.8s
2. **Database Bottlenecks**: Transaction timeouts during concurrent
operations
3. **Query Optimization**: Graph creation involves multiple large
database operations
4. **Connection Pooling**: Database connection limits under high
concurrency

###  **Configuration Fixes Applied**
- **Database Transaction Timeout**: Increased from 15s to 30s (bandaid
solution)
- **Block Execution API**: Fixed missing user_context parameter  
- **Credits API Error Handling**: Added proper exception handling
- **CI Tests**: Fixed test_execute_graph_block

**Note**: These are configuration fixes, not performance optimizations.
The underlying performance issues still need to be addressed through
query optimization, database tuning, and application-level improvements.

## 🛠️ Infrastructure Features

- **k6 Load Testing**: JavaScript-based scenarios with realistic user
workflows
- **Grafana Cloud Integration**: Real-time dashboards and alerting
- **Multi-Environment Support**: Dev, local, staging configurations
- **Authentication Testing**: Supabase JWT token validation
- **Performance Monitoring**: SLA validation with configurable
thresholds
- **Automated User Setup**: Test user creation and management

## 📁 Files Structure

```
load-tests/
├── basic-connectivity-test.js          # Infrastructure validation
├── core-api-load-test.js               # Core API testing  
├── graph-execution-load-test.js        # Graph operations
├── setup-test-users.js                 # User management
├── scenarios/
│   └── comprehensive-platform-load-test.js  # End-to-end testing
├── configs/
│   ├── environment.js                  # Environment settings
│   └── grafana-cloud.js               # Monitoring configuration
└── utils/
    └── auth.js                        # Authentication utilities
```

## 🎯 Next Steps for Performance Optimization

1. **Query Optimization**: Profile and optimize graph creation queries
2. **Database Tuning**: Optimize connection pooling and indexing
3. **Caching Strategy**: Implement appropriate caching for frequently
accessed data
4. **Load Balancing**: Fix uneven traffic distribution between pods
5. **Monitoring**: Use this load testing infrastructure to measure
improvements

##  Test Plan
- [x] All load testing scenarios validated locally
- [x] Grafana Cloud integration working
- [x] Test user setup automated
- [x] Performance baselines established
- [x] Critical performance bottlenecks identified
- [x] CI tests passing (test_execute_graph_block fixed)
- [x] Configuration issues resolved
- [ ] **Performance optimizations still needed** (separate work)

**This PR provides the infrastructure to identify and monitor
performance issues. The actual performance optimizations are separate
work that should be prioritized based on these findings.**

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-09-22 08:28:57 +07:00
Reinier van der Leer
30402979bb add tests 2025-09-21 20:22:23 +02:00
Reinier van der Leer
62924a76d1 feat(blocks): Implement lenient JSON parsing in AI Structured Response block 2025-09-21 19:49:57 +02:00
474 changed files with 9640 additions and 29021 deletions

View File

@@ -12,7 +12,6 @@ This file provides comprehensive onboarding information for GitHub Copilot codin
- **Infrastructure** - Docker configurations, CI/CD, and development tools
**Primary Languages & Frameworks:**
- **Backend**: Python 3.10-3.13, FastAPI, Prisma ORM, PostgreSQL, RabbitMQ
- **Frontend**: TypeScript, Next.js 15, React, Tailwind CSS, Radix UI
- **Development**: Docker, Poetry, pnpm, Playwright, Storybook
@@ -24,17 +23,15 @@ This file provides comprehensive onboarding information for GitHub Copilot codin
**Always run these commands in the correct directory and in this order:**
1. **Initial Setup** (required once):
```bash
# Clone and enter repository
git clone <repo> && cd AutoGPT
# Start all services (database, redis, rabbitmq, clamav)
cd autogpt_platform && docker compose --profile local up deps --build --detach
```
2. **Backend Setup** (always run before backend development):
```bash
cd autogpt_platform/backend
poetry install # Install dependencies
@@ -51,7 +48,6 @@ This file provides comprehensive onboarding information for GitHub Copilot codin
### Runtime Requirements
**Critical:** Always ensure Docker services are running before starting development:
```bash
cd autogpt_platform && docker compose --profile local up deps --build --detach
```
@@ -62,7 +58,6 @@ cd autogpt_platform && docker compose --profile local up deps --build --detach
### Development Commands
**Backend Development:**
```bash
cd autogpt_platform/backend
poetry run serve # Start development server (port 8000)
@@ -73,7 +68,6 @@ poetry run lint # Lint code (ruff) - run after format
```
**Frontend Development:**
```bash
cd autogpt_platform/frontend
pnpm dev # Start development server (port 3000) - use for active development
@@ -87,27 +81,23 @@ pnpm storybook # Start component development server
### Testing Strategy
**Backend Tests:**
- **Block Tests**: `poetry run pytest backend/blocks/test/test_block.py -xvs` (validates all blocks)
- **Specific Block**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[BlockName]' -xvs`
- **Snapshot Tests**: Use `--snapshot-update` when output changes, always review with `git diff`
**Frontend Tests:**
- **E2E Tests**: Always run `pnpm dev` before `pnpm test` (Playwright requires running instance)
- **Component Tests**: Use Storybook for isolated component development
### Critical Validation Steps
**Before committing changes:**
1. Run `poetry run format` (backend) and `pnpm format` (frontend)
2. Ensure all tests pass in modified areas
3. Verify Docker services are still running
4. Check that database migrations apply cleanly
**Common Issues & Workarounds:**
- **Prisma issues**: Run `poetry run prisma generate` after schema changes
- **Permission errors**: Ensure Docker has proper permissions
- **Port conflicts**: Check the `docker-compose.yml` file for the current list of exposed ports. You can list all mapped ports with:
@@ -118,7 +108,6 @@ pnpm storybook # Start component development server
### Core Architecture
**AutoGPT Platform** (`autogpt_platform/`):
- `backend/` - FastAPI server with async support
- `backend/backend/` - Core API logic
- `backend/blocks/` - Agent execution blocks
@@ -132,7 +121,6 @@ pnpm storybook # Start component development server
- `docker-compose.yml` - Development stack orchestration
**Key Configuration Files:**
- `pyproject.toml` - Python dependencies and tooling
- `package.json` - Node.js dependencies and scripts
- `schema.prisma` - Database schema and migrations
@@ -148,7 +136,6 @@ pnpm storybook # Start component development server
### Development Workflow
**GitHub Actions**: Multiple CI/CD workflows in `.github/workflows/`
- `platform-backend-ci.yml` - Backend testing and validation
- `platform-frontend-ci.yml` - Frontend testing and validation
- `platform-fullstack-ci.yml` - End-to-end integration tests
@@ -159,13 +146,11 @@ pnpm storybook # Start component development server
### Key Source Files
**Backend Entry Points:**
- `backend/backend/server/server.py` - FastAPI application setup
- `backend/backend/data/` - Database models and user management
- `backend/blocks/` - Agent execution blocks and logic
**Frontend Entry Points:**
- `frontend/src/app/layout.tsx` - Root application layout
- `frontend/src/app/page.tsx` - Home page
- `frontend/src/lib/supabase/` - Authentication and database client
@@ -175,7 +160,6 @@ pnpm storybook # Start component development server
### Agent Block System
Agents are built using a visual block-based system where each block performs a single action. Blocks are defined in `backend/blocks/` and must include:
- Block definition with input/output schemas
- Execution logic with proper error handling
- Tests validating functionality
@@ -183,7 +167,6 @@ Agents are built using a visual block-based system where each block performs a s
### Database & ORM
**Prisma ORM** with PostgreSQL backend including pgvector for embeddings:
- Schema in `schema.prisma`
- Migrations in `backend/migrations/`
- Always run `prisma migrate dev` and `prisma generate` after schema changes
@@ -191,15 +174,13 @@ Agents are built using a visual block-based system where each block performs a s
## Environment Configuration
### Configuration Files Priority Order
1. **Backend**: `/backend/.env.default` → `/backend/.env` (user overrides)
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
3. **Platform**: `/.env.default` (Supabase/shared) → `/.env` (user overrides)
4. Docker Compose `environment:` sections override file-based config
5. Shell environment variables have highest precedence
### Docker Environment Setup
- All services use hardcoded defaults (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
@@ -208,7 +189,6 @@ Agents are built using a visual block-based system where each block performs a s
## Advanced Development Patterns
### Adding New Blocks
1. Create file in `/backend/backend/blocks/`
2. Inherit from `Block` base class with input/output schemas
3. Implement `run` method with proper error handling
@@ -218,7 +198,6 @@ Agents are built using a visual block-based system where each block performs a s
7. Consider how inputs/outputs connect with other blocks in graph editor
### API Development
1. Update routes in `/backend/backend/server/routers/`
2. Add/update Pydantic models in same directory
3. Write tests alongside route files
@@ -226,76 +205,21 @@ Agents are built using a visual block-based system where each block performs a s
5. Run `poetry run test` to verify changes
### Frontend Development
**📖 Complete Frontend Guide**: See `autogpt_platform/frontend/CONTRIBUTING.md` and `autogpt_platform/frontend/.cursorrules` for comprehensive patterns and conventions.
**Quick Reference:**
**Component Structure:**
- Separate render logic from data/behavior
- Structure: `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
- Exception: Small components (3-4 lines of logic) can be inline
- Render-only components can be direct files without folders
**Data Fetching:**
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Generated via Orval from backend OpenAPI spec
- Pattern: `use{Method}{Version}{OperationName}`
- Example: `useGetV2ListLibraryAgents`
- Regenerate with: `pnpm generate:api`
- **Never** use deprecated `BackendAPI` or `src/lib/autogpt-server-api/*`
**Code Conventions:**
- Use function declarations for components and handlers (not arrow functions)
- Only arrow functions for small inline lambdas (map, filter, etc.)
- Components: `PascalCase`, Hooks: `camelCase` with `use` prefix
- No barrel files or `index.ts` re-exports
- Minimal comments (code should be self-documenting)
**Styling:**
- Use Tailwind CSS utilities only
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Never use `src/components/__legacy__/*`
- Only use Phosphor Icons (`@phosphor-icons/react`)
- Prefer design tokens over hardcoded values
**Error Handling:**
- Render errors: Use `<ErrorCard />` component
- Mutation errors: Display with toast notifications
- Manual exceptions: Use `Sentry.captureException()`
- Global error boundaries already configured
**Testing:**
- Add/update Storybook stories for UI components (`pnpm storybook`)
- Run Playwright E2E tests with `pnpm test`
- Verify in Chromatic after PR
**Architecture:**
- Default to client components ("use client")
- Server components only for SEO or extreme TTFB needs
- Use React Query for server state (via generated hooks)
- Co-locate UI state in components/hooks
1. Components in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for component development
4. Test user-facing features with Playwright E2E tests
5. Update protected routes in middleware when needed
### Security Guidelines
**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):
- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses allow list approach for cacheable paths (static assets, health checks, public pages)
- Prevents sensitive data caching in browsers/proxies
- Add new cacheable endpoints to `CACHEABLE_PATHS`
### CI/CD Alignment
The repository has comprehensive CI workflows that test:
- **Backend**: Python 3.11-3.13, services (Redis/RabbitMQ/ClamAV), Prisma migrations, Poetry lock validation
- **Frontend**: Node.js 21, pnpm, Playwright with Docker Compose stack, API schema validation
- **Integration**: Full-stack type checking and E2E testing
@@ -305,7 +229,6 @@ Match these patterns when developing locally - the copilot setup environment mir
## Collaboration with Other AI Assistants
This repository is actively developed with assistance from Claude (via CLAUDE.md files). When working on this codebase:
- Check for existing CLAUDE.md files that provide additional context
- Follow established patterns and conventions already in the codebase
- Maintain consistency with existing code style and architecture
@@ -314,9 +237,8 @@ This repository is actively developed with assistance from Claude (via CLAUDE.md
## Trust These Instructions
These instructions are comprehensive and tested. Only perform additional searches if:
1. Information here is incomplete for your specific task
2. You encounter errors not covered by the workarounds
3. You need to understand implementation details not covered above
For detailed platform development patterns, refer to `autogpt_platform/CLAUDE.md` and `AGENTS.md` in the repository root.
For detailed platform development patterns, refer to `autogpt_platform/CLAUDE.md` and `AGENTS.md` in the repository root.

View File

@@ -1,3 +1,6 @@
[pr_reviewer]
num_code_suggestions=0
[pr_code_suggestions]
commitable_code_suggestions=false
num_code_suggestions=0

View File

@@ -63,9 +63,6 @@ poetry run pytest path/to/test.py --snapshot-update
# Install dependencies
cd frontend && pnpm i
# Generate API client from OpenAPI spec
pnpm generate:api
# Start development server
pnpm dev
@@ -78,23 +75,12 @@ pnpm storybook
# Build production
pnpm build
# Format and lint
pnpm format
# Type checking
pnpm types
```
**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns.
We have a components library in autogpt_platform/frontend/src/components/atoms that should be used when adding new pages and components.
**Key Frontend Conventions:**
- Separate render logic from data/behavior in components
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Use function declarations (not arrow functions) for components/handlers
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Only use Phosphor Icons
- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`
## Architecture Overview
@@ -109,16 +95,11 @@ pnpm types
### Frontend Architecture
- **Framework**: Next.js 15 App Router (client-first approach)
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
- **State Management**: React Query for server state, co-located UI state in components/hooks
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
- **Framework**: Next.js App Router with React Server Components
- **State Management**: React hooks + Supabase client for real-time updates
- **Workflow Builder**: Visual graph editor using @xyflow/react
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
- **Icons**: Phosphor Icons only
- **UI Components**: Radix UI primitives with Tailwind CSS styling
- **Feature Flags**: LaunchDarkly integration
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
- **Testing**: Playwright for E2E, Storybook for component development
### Key Concepts
@@ -172,7 +153,6 @@ Key models (defined in `/backend/schema.prisma`):
**Adding a new block:**
Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
@@ -180,7 +160,6 @@ Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-
- File organization
Quick steps:
1. Create new file in `/backend/backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
@@ -201,20 +180,10 @@ ex: do the inputs and outputs tie well together?
**Frontend feature development:**
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
- Add `usePageName.ts` hook for logic
- Put sub-components in local `components/` folder
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Never use `src/components/__legacy__/*`
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Regenerate with `pnpm generate:api`
- Pattern: `use{Method}{Version}{OperationName}`
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
1. Components go in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for new components
4. Test with Playwright if user-facing
### Security Implementation

View File

@@ -1,57 +0,0 @@
.PHONY: start-core stop-core logs-core format lint migrate run-backend run-frontend
# Run just Supabase + Redis + RabbitMQ
start-core:
docker compose up -d deps
# Stop core services
stop-core:
docker compose stop deps
reset-db:
rm -rf db/docker/volumes/db/data
cd backend && poetry run prisma migrate deploy
cd backend && poetry run prisma generate
# View logs for core services
logs-core:
docker compose logs -f deps
# Run formatting and linting for backend and frontend
format:
cd backend && poetry run format
cd frontend && pnpm format
cd frontend && pnpm lint
init-env:
cp -n .env.default .env || true
cd backend && cp -n .env.default .env || true
cd frontend && cp -n .env.default .env || true
# Run migrations for backend
migrate:
cd backend && poetry run prisma migrate deploy
cd backend && poetry run prisma generate
run-backend:
cd backend && poetry run app
run-frontend:
cd frontend && pnpm dev
test-data:
cd backend && poetry run python test/test_data_creator.py
help:
@echo "Usage: make <target>"
@echo "Targets:"
@echo " start-core - Start just the core services (Supabase, Redis, RabbitMQ) in background"
@echo " stop-core - Stop the core services"
@echo " reset-db - Reset the database by deleting the volume"
@echo " logs-core - Tail the logs for core services"
@echo " format - Format & lint backend (Python) and frontend (TypeScript) code"
@echo " migrate - Run backend database migrations"
@echo " run-backend - Run the backend FastAPI server"
@echo " run-frontend - Run the frontend Next.js development server"
@echo " test-data - Run the test data creator"

View File

@@ -38,37 +38,6 @@ To run the AutoGPT Platform, follow these steps:
4. After all the services are in ready state, open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
### Running Just Core services
You can now run the following to enable just the core services.
```
# For help
make help
# Run just Supabase + Redis + RabbitMQ
make start-core
# Stop core services
make stop-core
# View logs from core services
make logs-core
# Run formatting and linting for backend and frontend
make format
# Run migrations for backend database
make migrate
# Run backend server
make run-backend
# Run frontend development server
make run-frontend
```
### Docker Compose Commands
Here are some useful Docker Compose commands for managing your AutoGPT Platform:

View File

@@ -10,7 +10,7 @@ from .jwt_utils import get_jwt_payload, verify_user
from .models import User
async def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
"""
FastAPI dependency that requires a valid authenticated user.
@@ -20,9 +20,7 @@ async def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -
return verify_user(jwt_payload, admin_only=False)
async def requires_admin_user(
jwt_payload: dict = fastapi.Security(get_jwt_payload),
) -> User:
def requires_admin_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
"""
FastAPI dependency that requires a valid admin user.
@@ -32,7 +30,7 @@ async def requires_admin_user(
return verify_user(jwt_payload, admin_only=True)
async def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
"""
FastAPI dependency that returns the ID of the authenticated user.

View File

@@ -45,7 +45,7 @@ class TestAuthDependencies:
"""Create a test client."""
return TestClient(app)
async def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
"""Test requires_user with valid JWT payload."""
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
@@ -53,12 +53,12 @@ class TestAuthDependencies:
mocker.patch(
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user = await requires_user(jwt_payload)
user = requires_user(jwt_payload)
assert isinstance(user, User)
assert user.user_id == "user-123"
assert user.role == "user"
async def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
"""Test requires_user accepts admin users."""
jwt_payload = {
"sub": "admin-456",
@@ -69,28 +69,28 @@ class TestAuthDependencies:
mocker.patch(
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user = await requires_user(jwt_payload)
user = requires_user(jwt_payload)
assert user.user_id == "admin-456"
assert user.role == "admin"
async def test_requires_user_missing_sub(self):
def test_requires_user_missing_sub(self):
"""Test requires_user with missing user ID."""
jwt_payload = {"role": "user", "email": "user@example.com"}
with pytest.raises(HTTPException) as exc_info:
await requires_user(jwt_payload)
requires_user(jwt_payload)
assert exc_info.value.status_code == 401
assert "User ID not found" in exc_info.value.detail
async def test_requires_user_empty_sub(self):
def test_requires_user_empty_sub(self):
"""Test requires_user with empty user ID."""
jwt_payload = {"sub": "", "role": "user"}
with pytest.raises(HTTPException) as exc_info:
await requires_user(jwt_payload)
requires_user(jwt_payload)
assert exc_info.value.status_code == 401
async def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
"""Test requires_admin_user with admin role."""
jwt_payload = {
"sub": "admin-789",
@@ -101,51 +101,51 @@ class TestAuthDependencies:
mocker.patch(
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user = await requires_admin_user(jwt_payload)
user = requires_admin_user(jwt_payload)
assert user.user_id == "admin-789"
assert user.role == "admin"
async def test_requires_admin_user_with_regular_user(self):
def test_requires_admin_user_with_regular_user(self):
"""Test requires_admin_user rejects regular users."""
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
with pytest.raises(HTTPException) as exc_info:
await requires_admin_user(jwt_payload)
requires_admin_user(jwt_payload)
assert exc_info.value.status_code == 403
assert "Admin access required" in exc_info.value.detail
async def test_requires_admin_user_missing_role(self):
def test_requires_admin_user_missing_role(self):
"""Test requires_admin_user with missing role."""
jwt_payload = {"sub": "user-123", "email": "user@example.com"}
with pytest.raises(KeyError):
await requires_admin_user(jwt_payload)
requires_admin_user(jwt_payload)
async def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
"""Test get_user_id extracts user ID correctly."""
jwt_payload = {"sub": "user-id-xyz", "role": "user"}
mocker.patch(
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
)
user_id = await get_user_id(jwt_payload)
user_id = get_user_id(jwt_payload)
assert user_id == "user-id-xyz"
async def test_get_user_id_missing_sub(self):
def test_get_user_id_missing_sub(self):
"""Test get_user_id with missing user ID."""
jwt_payload = {"role": "user"}
with pytest.raises(HTTPException) as exc_info:
await get_user_id(jwt_payload)
get_user_id(jwt_payload)
assert exc_info.value.status_code == 401
assert "User ID not found" in exc_info.value.detail
async def test_get_user_id_none_sub(self):
def test_get_user_id_none_sub(self):
"""Test get_user_id with None user ID."""
jwt_payload = {"sub": None, "role": "user"}
with pytest.raises(HTTPException) as exc_info:
await get_user_id(jwt_payload)
get_user_id(jwt_payload)
assert exc_info.value.status_code == 401
@@ -170,7 +170,7 @@ class TestAuthDependenciesIntegration:
return _create_token
async def test_endpoint_auth_enabled_no_token(self):
def test_endpoint_auth_enabled_no_token(self):
"""Test endpoints require token when auth is enabled."""
app = FastAPI()
@@ -184,7 +184,7 @@ class TestAuthDependenciesIntegration:
response = client.get("/test")
assert response.status_code == 401
async def test_endpoint_with_valid_token(self, create_token):
def test_endpoint_with_valid_token(self, create_token):
"""Test endpoint with valid JWT token."""
app = FastAPI()
@@ -203,7 +203,7 @@ class TestAuthDependenciesIntegration:
assert response.status_code == 200
assert response.json()["user_id"] == "test-user"
async def test_admin_endpoint_requires_admin_role(self, create_token):
def test_admin_endpoint_requires_admin_role(self, create_token):
"""Test admin endpoint rejects non-admin users."""
app = FastAPI()
@@ -240,7 +240,7 @@ class TestAuthDependenciesIntegration:
class TestAuthDependenciesEdgeCases:
"""Edge case tests for authentication dependencies."""
async def test_dependency_with_complex_payload(self):
def test_dependency_with_complex_payload(self):
"""Test dependencies handle complex JWT payloads."""
complex_payload = {
"sub": "user-123",
@@ -256,14 +256,14 @@ class TestAuthDependenciesEdgeCases:
"exp": 9999999999,
}
user = await requires_user(complex_payload)
user = requires_user(complex_payload)
assert user.user_id == "user-123"
assert user.email == "test@example.com"
admin = await requires_admin_user(complex_payload)
admin = requires_admin_user(complex_payload)
assert admin.role == "admin"
async def test_dependency_with_unicode_in_payload(self):
def test_dependency_with_unicode_in_payload(self):
"""Test dependencies handle unicode in JWT payloads."""
unicode_payload = {
"sub": "user-😀-123",
@@ -272,11 +272,11 @@ class TestAuthDependenciesEdgeCases:
"name": "日本語",
}
user = await requires_user(unicode_payload)
user = requires_user(unicode_payload)
assert "😀" in user.user_id
assert user.email == "测试@example.com"
async def test_dependency_with_null_values(self):
def test_dependency_with_null_values(self):
"""Test dependencies handle null values in payload."""
null_payload = {
"sub": "user-123",
@@ -286,18 +286,18 @@ class TestAuthDependenciesEdgeCases:
"metadata": None,
}
user = await requires_user(null_payload)
user = requires_user(null_payload)
assert user.user_id == "user-123"
assert user.email is None
async def test_concurrent_requests_isolation(self):
def test_concurrent_requests_isolation(self):
"""Test that concurrent requests don't interfere with each other."""
payload1 = {"sub": "user-1", "role": "user"}
payload2 = {"sub": "user-2", "role": "admin"}
# Simulate concurrent processing
user1 = await requires_user(payload1)
user2 = await requires_admin_user(payload2)
user1 = requires_user(payload1)
user2 = requires_admin_user(payload2)
assert user1.user_id == "user-1"
assert user2.user_id == "user-2"
@@ -314,7 +314,7 @@ class TestAuthDependenciesEdgeCases:
({"sub": "user", "role": "user"}, "Admin access required", True),
],
)
async def test_dependency_error_cases(
def test_dependency_error_cases(
self, payload, expected_error: str, admin_only: bool
):
"""Test that errors propagate correctly through dependencies."""
@@ -325,7 +325,7 @@ class TestAuthDependenciesEdgeCases:
verify_user(payload, admin_only=admin_only)
assert expected_error in exc_info.value.detail
async def test_dependency_valid_user(self):
def test_dependency_valid_user(self):
"""Test valid user case for dependency."""
# Import verify_user to test it directly since dependencies use FastAPI Security
from autogpt_libs.auth.jwt_utils import verify_user

View File

@@ -16,7 +16,7 @@ bearer_jwt_auth = HTTPBearer(
)
async def get_jwt_payload(
def get_jwt_payload(
credentials: HTTPAuthorizationCredentials | None = Security(bearer_jwt_auth),
) -> dict[str, Any]:
"""

View File

@@ -116,32 +116,32 @@ def test_parse_jwt_token_missing_audience():
assert "Invalid token" in str(exc_info.value)
async def test_get_jwt_payload_with_valid_token():
def test_get_jwt_payload_with_valid_token():
"""Test extracting JWT payload with valid bearer token."""
token = create_token(TEST_USER_PAYLOAD)
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
result = await jwt_utils.get_jwt_payload(credentials)
result = jwt_utils.get_jwt_payload(credentials)
assert result["sub"] == "test-user-id"
assert result["role"] == "user"
async def test_get_jwt_payload_no_credentials():
def test_get_jwt_payload_no_credentials():
"""Test JWT payload when no credentials provided."""
with pytest.raises(HTTPException) as exc_info:
await jwt_utils.get_jwt_payload(None)
jwt_utils.get_jwt_payload(None)
assert exc_info.value.status_code == 401
assert "Authorization header is missing" in exc_info.value.detail
async def test_get_jwt_payload_invalid_token():
def test_get_jwt_payload_invalid_token():
"""Test JWT payload extraction with invalid token."""
credentials = HTTPAuthorizationCredentials(
scheme="Bearer", credentials="invalid.token.here"
)
with pytest.raises(HTTPException) as exc_info:
await jwt_utils.get_jwt_payload(credentials)
jwt_utils.get_jwt_payload(credentials)
assert exc_info.value.status_code == 401
assert "Invalid token" in exc_info.value.detail

View File

@@ -4,7 +4,6 @@ import logging
import os
import socket
import sys
from logging.handlers import RotatingFileHandler
from pathlib import Path
from pydantic import Field, field_validator
@@ -94,36 +93,42 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
config = LoggingConfig()
log_handlers: list[logging.Handler] = []
structured_logging = config.enable_cloud_logging or force_cloud_logging
# Console output handlers
if not structured_logging:
stdout = logging.StreamHandler(stream=sys.stdout)
stdout.setLevel(config.level)
stdout.addFilter(BelowLevelFilter(logging.WARNING))
if config.level == logging.DEBUG:
stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
else:
stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
stdout = logging.StreamHandler(stream=sys.stdout)
stdout.setLevel(config.level)
stdout.addFilter(BelowLevelFilter(logging.WARNING))
if config.level == logging.DEBUG:
stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
else:
stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
stderr = logging.StreamHandler()
stderr.setLevel(logging.WARNING)
if config.level == logging.DEBUG:
stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
else:
stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
stderr = logging.StreamHandler()
stderr.setLevel(logging.WARNING)
if config.level == logging.DEBUG:
stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
else:
stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
log_handlers += [stdout, stderr]
log_handlers += [stdout, stderr]
# Cloud logging setup
else:
# Use Google Cloud Structured Log Handler. Log entries are printed to stdout
# in a JSON format which is automatically picked up by Google Cloud Logging.
from google.cloud.logging.handlers import StructuredLogHandler
if config.enable_cloud_logging or force_cloud_logging:
import google.cloud.logging
from google.cloud.logging.handlers import CloudLoggingHandler
from google.cloud.logging_v2.handlers.transports import (
BackgroundThreadTransport,
)
structured_log_handler = StructuredLogHandler(stream=sys.stdout)
structured_log_handler.setLevel(config.level)
log_handlers.append(structured_log_handler)
client = google.cloud.logging.Client()
# Use BackgroundThreadTransport to prevent blocking the main thread
# and deadlocks when gRPC calls to Google Cloud Logging hang
cloud_handler = CloudLoggingHandler(
client,
name="autogpt_logs",
transport=BackgroundThreadTransport,
)
cloud_handler.setLevel(config.level)
log_handlers.append(cloud_handler)
# File logging setup
if config.enable_file_logging:
@@ -134,13 +139,8 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
print(f"Log directory: {config.log_dir}")
# Activity log handler (INFO and above)
# Security fix: Use RotatingFileHandler with size limits to prevent disk exhaustion
activity_log_handler = RotatingFileHandler(
config.log_dir / LOG_FILE,
mode="a",
encoding="utf-8",
maxBytes=10 * 1024 * 1024, # 10MB per file
backupCount=3, # Keep 3 backup files (40MB total)
activity_log_handler = logging.FileHandler(
config.log_dir / LOG_FILE, "a", "utf-8"
)
activity_log_handler.setLevel(config.level)
activity_log_handler.setFormatter(
@@ -150,13 +150,8 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
if config.level == logging.DEBUG:
# Debug log handler (all levels)
# Security fix: Use RotatingFileHandler with size limits
debug_log_handler = RotatingFileHandler(
config.log_dir / DEBUG_LOG_FILE,
mode="a",
encoding="utf-8",
maxBytes=10 * 1024 * 1024, # 10MB per file
backupCount=3, # Keep 3 backup files (40MB total)
debug_log_handler = logging.FileHandler(
config.log_dir / DEBUG_LOG_FILE, "a", "utf-8"
)
debug_log_handler.setLevel(logging.DEBUG)
debug_log_handler.setFormatter(
@@ -165,13 +160,8 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
log_handlers.append(debug_log_handler)
# Error log handler (ERROR and above)
# Security fix: Use RotatingFileHandler with size limits
error_log_handler = RotatingFileHandler(
config.log_dir / ERROR_LOG_FILE,
mode="a",
encoding="utf-8",
maxBytes=10 * 1024 * 1024, # 10MB per file
backupCount=3, # Keep 3 backup files (40MB total)
error_log_handler = logging.FileHandler(
config.log_dir / ERROR_LOG_FILE, "a", "utf-8"
)
error_log_handler.setLevel(logging.ERROR)
error_log_handler.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT, no_color=True))
@@ -179,13 +169,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
# Configure the root logger
logging.basicConfig(
format=(
"%(levelname)s %(message)s"
if structured_logging
else (
DEBUG_LOG_FORMAT if config.level == logging.DEBUG else SIMPLE_LOG_FORMAT
)
),
format=DEBUG_LOG_FORMAT if config.level == logging.DEBUG else SIMPLE_LOG_FORMAT,
level=config.level,
handlers=log_handlers,
)

View File

@@ -0,0 +1,328 @@
import asyncio
import inspect
import logging
import threading
import time
from functools import wraps
from typing import (
Any,
Callable,
ParamSpec,
Protocol,
TypeVar,
cast,
runtime_checkable,
)
P = ParamSpec("P")
R = TypeVar("R")
R_co = TypeVar("R_co", covariant=True)
logger = logging.getLogger(__name__)
def _make_hashable_key(
args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[Any, ...]:
"""
Convert args and kwargs into a hashable cache key.
Handles unhashable types like dict, list, set by converting them to
their sorted string representations.
"""
def make_hashable(obj: Any) -> Any:
"""Recursively convert an object to a hashable representation."""
if isinstance(obj, dict):
# Sort dict items to ensure consistent ordering
return (
"__dict__",
tuple(sorted((k, make_hashable(v)) for k, v in obj.items())),
)
elif isinstance(obj, (list, tuple)):
return ("__list__", tuple(make_hashable(item) for item in obj))
elif isinstance(obj, set):
return ("__set__", tuple(sorted(make_hashable(item) for item in obj)))
elif hasattr(obj, "__dict__"):
# Handle objects with __dict__ attribute
return ("__obj__", obj.__class__.__name__, make_hashable(obj.__dict__))
else:
# For basic hashable types (str, int, bool, None, etc.)
try:
hash(obj)
return obj
except TypeError:
# Fallback: convert to string representation
return ("__str__", str(obj))
hashable_args = tuple(make_hashable(arg) for arg in args)
hashable_kwargs = tuple(sorted((k, make_hashable(v)) for k, v in kwargs.items()))
return (hashable_args, hashable_kwargs)
@runtime_checkable
class CachedFunction(Protocol[P, R_co]):
"""Protocol for cached functions with cache management methods."""
def cache_clear(self) -> None:
"""Clear all cached entries."""
return None
def cache_info(self) -> dict[str, int | None]:
"""Get cache statistics."""
return {}
def cache_delete(self, *args: P.args, **kwargs: P.kwargs) -> bool:
"""Delete a specific cache entry by its arguments. Returns True if entry existed."""
return False
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
"""Call the cached function."""
return None # type: ignore
def cached(
*,
maxsize: int = 128,
ttl_seconds: int | None = None,
) -> Callable[[Callable], CachedFunction]:
"""
Thundering herd safe cache decorator for both sync and async functions.
Uses double-checked locking to prevent multiple threads/coroutines from
executing the expensive operation simultaneously during cache misses.
Args:
func: The function to cache (when used without parentheses)
maxsize: Maximum number of cached entries
ttl_seconds: Time to live in seconds. If None, entries never expire
Returns:
Decorated function or decorator
Example:
@cache() # Default: maxsize=128, no TTL
def expensive_sync_operation(param: str) -> dict:
return {"result": param}
@cache() # Works with async too
async def expensive_async_operation(param: str) -> dict:
return {"result": param}
@cache(maxsize=1000, ttl_seconds=300) # Custom maxsize and TTL
def another_operation(param: str) -> dict:
return {"result": param}
"""
def decorator(target_func):
# Cache storage and locks
cache_storage = {}
if inspect.iscoroutinefunction(target_func):
# Async function with asyncio.Lock
cache_lock = asyncio.Lock()
@wraps(target_func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs):
key = _make_hashable_key(args, kwargs)
current_time = time.time()
# Fast path: check cache without lock
if key in cache_storage:
if ttl_seconds is None:
logger.debug(f"Cache hit for {target_func.__name__}")
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.debug(f"Cache hit for {target_func.__name__}")
return result
# Slow path: acquire lock for cache miss/expiry
async with cache_lock:
# Double-check: another coroutine might have populated cache
if key in cache_storage:
if ttl_seconds is None:
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {target_func.__name__}")
result = await target_func(*args, **kwargs)
# Store result
if ttl_seconds is None:
cache_storage[key] = result
else:
cache_storage[key] = (result, current_time)
# Cleanup if needed
if len(cache_storage) > maxsize:
cutoff = maxsize // 2
oldest_keys = (
list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
)
for old_key in oldest_keys:
cache_storage.pop(old_key, None)
return result
wrapper = async_wrapper
else:
# Sync function with threading.Lock
cache_lock = threading.Lock()
@wraps(target_func)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs):
key = _make_hashable_key(args, kwargs)
current_time = time.time()
# Fast path: check cache without lock
if key in cache_storage:
if ttl_seconds is None:
logger.debug(f"Cache hit for {target_func.__name__}")
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.debug(f"Cache hit for {target_func.__name__}")
return result
# Slow path: acquire lock for cache miss/expiry
with cache_lock:
# Double-check: another thread might have populated cache
if key in cache_storage:
if ttl_seconds is None:
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {target_func.__name__}")
result = target_func(*args, **kwargs)
# Store result
if ttl_seconds is None:
cache_storage[key] = result
else:
cache_storage[key] = (result, current_time)
# Cleanup if needed
if len(cache_storage) > maxsize:
cutoff = maxsize // 2
oldest_keys = (
list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
)
for old_key in oldest_keys:
cache_storage.pop(old_key, None)
return result
wrapper = sync_wrapper
# Add cache management methods
def cache_clear() -> None:
cache_storage.clear()
def cache_info() -> dict[str, int | None]:
return {
"size": len(cache_storage),
"maxsize": maxsize,
"ttl_seconds": ttl_seconds,
}
def cache_delete(*args, **kwargs) -> bool:
"""Delete a specific cache entry. Returns True if entry existed."""
key = _make_hashable_key(args, kwargs)
if key in cache_storage:
del cache_storage[key]
return True
return False
setattr(wrapper, "cache_clear", cache_clear)
setattr(wrapper, "cache_info", cache_info)
setattr(wrapper, "cache_delete", cache_delete)
return cast(CachedFunction, wrapper)
return decorator
def thread_cached(func):
"""
Thread-local cache decorator for both sync and async functions.
Each thread gets its own cache, which is useful for request-scoped caching
in web applications where you want to cache within a single request but
not across requests.
Args:
func: The function to cache
Returns:
Decorated function with thread-local caching
Example:
@thread_cached
def expensive_operation(param: str) -> dict:
return {"result": param}
@thread_cached # Works with async too
async def expensive_async_operation(param: str) -> dict:
return {"result": param}
"""
thread_local = threading.local()
def _clear():
if hasattr(thread_local, "cache"):
del thread_local.cache
if inspect.iscoroutinefunction(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = _make_hashable_key(args, kwargs)
if key not in cache:
cache[key] = await func(*args, **kwargs)
return cache[key]
setattr(async_wrapper, "clear_cache", _clear)
return async_wrapper
else:
@wraps(func)
def sync_wrapper(*args, **kwargs):
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = _make_hashable_key(args, kwargs)
if key not in cache:
cache[key] = func(*args, **kwargs)
return cache[key]
setattr(sync_wrapper, "clear_cache", _clear)
return sync_wrapper
def clear_thread_cache(func: Callable) -> None:
"""Clear thread-local cache for a function."""
if clear := getattr(func, "clear_cache", None):
clear()

View File

@@ -16,7 +16,7 @@ from unittest.mock import Mock
import pytest
from backend.util.cache import cached, clear_thread_cache, thread_cached
from autogpt_libs.utils.cache import cached, clear_thread_cache, thread_cached
class TestThreadCached:
@@ -332,7 +332,7 @@ class TestCache:
"""Test basic sync caching functionality."""
call_count = 0
@cached(ttl_seconds=300)
@cached()
def expensive_sync_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
@@ -358,7 +358,7 @@ class TestCache:
"""Test basic async caching functionality."""
call_count = 0
@cached(ttl_seconds=300)
@cached()
async def expensive_async_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
@@ -385,7 +385,7 @@ class TestCache:
call_count = 0
results = []
@cached(ttl_seconds=300)
@cached()
def slow_function(x: int) -> int:
nonlocal call_count
call_count += 1
@@ -412,7 +412,7 @@ class TestCache:
"""Test that concurrent async calls don't cause thundering herd."""
call_count = 0
@cached(ttl_seconds=300)
@cached()
async def slow_async_function(x: int) -> int:
nonlocal call_count
call_count += 1
@@ -508,7 +508,7 @@ class TestCache:
"""Test cache clearing functionality."""
call_count = 0
@cached(ttl_seconds=300)
@cached()
def clearable_function(x: int) -> int:
nonlocal call_count
call_count += 1
@@ -537,7 +537,7 @@ class TestCache:
"""Test cache clearing functionality with async function."""
call_count = 0
@cached(ttl_seconds=300)
@cached()
async def async_clearable_function(x: int) -> int:
nonlocal call_count
call_count += 1
@@ -567,7 +567,7 @@ class TestCache:
"""Test that cached async functions return actual results, not coroutines."""
call_count = 0
@cached(ttl_seconds=300)
@cached()
async def async_result_function(x: int) -> str:
nonlocal call_count
call_count += 1
@@ -593,7 +593,7 @@ class TestCache:
"""Test selective cache deletion functionality."""
call_count = 0
@cached(ttl_seconds=300)
@cached()
def deletable_function(x: int) -> int:
nonlocal call_count
call_count += 1
@@ -636,7 +636,7 @@ class TestCache:
"""Test selective cache deletion functionality with async function."""
call_count = 0
@cached(ttl_seconds=300)
@cached()
async def async_deletable_function(x: int) -> int:
nonlocal call_count
call_count += 1
@@ -674,450 +674,3 @@ class TestCache:
# Try to delete non-existent entry
was_deleted = async_deletable_function.cache_delete(99)
assert was_deleted is False
class TestSharedCache:
"""Tests for shared_cache (Redis-backed) functionality."""
def test_sync_shared_cache_basic(self):
"""Test basic shared cache functionality with sync function."""
call_count = 0
@cached(ttl_seconds=30, shared_cache=True)
def shared_sync_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
return x + y
# Clear any existing cache
shared_sync_function.cache_clear()
# First call
result1 = shared_sync_function(10, 20)
assert result1 == 30
assert call_count == 1
# Second call - should use Redis cache
result2 = shared_sync_function(10, 20)
assert result2 == 30
assert call_count == 1
# Different args - should call function again
result3 = shared_sync_function(15, 25)
assert result3 == 40
assert call_count == 2
# Cleanup
shared_sync_function.cache_clear()
@pytest.mark.asyncio
async def test_async_shared_cache_basic(self):
"""Test basic shared cache functionality with async function."""
call_count = 0
@cached(ttl_seconds=30, shared_cache=True)
async def shared_async_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
return x + y
# Clear any existing cache
shared_async_function.cache_clear()
# First call
result1 = await shared_async_function(10, 20)
assert result1 == 30
assert call_count == 1
# Second call - should use Redis cache
result2 = await shared_async_function(10, 20)
assert result2 == 30
assert call_count == 1
# Different args - should call function again
result3 = await shared_async_function(15, 25)
assert result3 == 40
assert call_count == 2
# Cleanup
shared_async_function.cache_clear()
def test_shared_cache_ttl_refresh(self):
"""Test TTL refresh functionality with shared cache."""
call_count = 0
@cached(ttl_seconds=2, shared_cache=True, refresh_ttl_on_get=True)
def ttl_refresh_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 10
# Clear any existing cache
ttl_refresh_function.cache_clear()
# First call
result1 = ttl_refresh_function(3)
assert result1 == 30
assert call_count == 1
# Wait 1 second
time.sleep(1)
# Second call - should refresh TTL and use cache
result2 = ttl_refresh_function(3)
assert result2 == 30
assert call_count == 1
# Wait another 1.5 seconds (total 2.5s from first call, 1.5s from second)
time.sleep(1.5)
# Third call - TTL should have been refreshed, so still cached
result3 = ttl_refresh_function(3)
assert result3 == 30
assert call_count == 1
# Wait 2.1 seconds - now it should expire
time.sleep(2.1)
# Fourth call - should call function again
result4 = ttl_refresh_function(3)
assert result4 == 30
assert call_count == 2
# Cleanup
ttl_refresh_function.cache_clear()
def test_shared_cache_without_ttl_refresh(self):
"""Test that TTL doesn't refresh when refresh_ttl_on_get=False."""
call_count = 0
@cached(ttl_seconds=2, shared_cache=True, refresh_ttl_on_get=False)
def no_ttl_refresh_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 10
# Clear any existing cache
no_ttl_refresh_function.cache_clear()
# First call
result1 = no_ttl_refresh_function(4)
assert result1 == 40
assert call_count == 1
# Wait 1 second
time.sleep(1)
# Second call - should use cache but NOT refresh TTL
result2 = no_ttl_refresh_function(4)
assert result2 == 40
assert call_count == 1
# Wait another 1.1 seconds (total 2.1s from first call)
time.sleep(1.1)
# Third call - should have expired
result3 = no_ttl_refresh_function(4)
assert result3 == 40
assert call_count == 2
# Cleanup
no_ttl_refresh_function.cache_clear()
def test_shared_cache_complex_objects(self):
"""Test caching complex objects with shared cache (pickle serialization)."""
call_count = 0
@cached(ttl_seconds=30, shared_cache=True)
def complex_object_function(x: int) -> dict:
nonlocal call_count
call_count += 1
return {
"number": x,
"squared": x**2,
"nested": {"list": [1, 2, x], "tuple": (x, x * 2)},
"string": f"value_{x}",
}
# Clear any existing cache
complex_object_function.cache_clear()
# First call
result1 = complex_object_function(5)
assert result1["number"] == 5
assert result1["squared"] == 25
assert result1["nested"]["list"] == [1, 2, 5]
assert call_count == 1
# Second call - should use cache
result2 = complex_object_function(5)
assert result2 == result1
assert call_count == 1
# Cleanup
complex_object_function.cache_clear()
def test_shared_cache_info(self):
"""Test cache_info for shared cache."""
@cached(ttl_seconds=30, shared_cache=True)
def info_shared_function(x: int) -> int:
return x * 2
# Clear any existing cache
info_shared_function.cache_clear()
# Check initial info
info = info_shared_function.cache_info()
assert info["size"] == 0
assert info["maxsize"] is None # Redis manages size
assert info["ttl_seconds"] == 30
# Add some entries
info_shared_function(1)
info_shared_function(2)
info_shared_function(3)
info = info_shared_function.cache_info()
assert info["size"] == 3
# Cleanup
info_shared_function.cache_clear()
def test_shared_cache_delete(self):
"""Test selective deletion with shared cache."""
call_count = 0
@cached(ttl_seconds=30, shared_cache=True)
def delete_shared_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 3
# Clear any existing cache
delete_shared_function.cache_clear()
# Add entries
delete_shared_function(1)
delete_shared_function(2)
delete_shared_function(3)
assert call_count == 3
# Verify cached
delete_shared_function(1)
delete_shared_function(2)
assert call_count == 3
# Delete specific entry
was_deleted = delete_shared_function.cache_delete(2)
assert was_deleted is True
# Entry for x=2 should be gone
delete_shared_function(2)
assert call_count == 4
# Others should still be cached
delete_shared_function(1)
delete_shared_function(3)
assert call_count == 4
# Try to delete non-existent
was_deleted = delete_shared_function.cache_delete(99)
assert was_deleted is False
# Cleanup
delete_shared_function.cache_clear()
@pytest.mark.asyncio
async def test_async_shared_cache_thundering_herd(self):
"""Test that shared cache prevents thundering herd for async functions."""
call_count = 0
@cached(ttl_seconds=30, shared_cache=True)
async def shared_slow_function(x: int) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.1)
return x * x
# Clear any existing cache
shared_slow_function.cache_clear()
# Launch multiple concurrent tasks
tasks = [shared_slow_function(8) for _ in range(10)]
results = await asyncio.gather(*tasks)
# All should return same result
assert all(r == 64 for r in results)
# Only one should have executed
assert call_count == 1
# Cleanup
shared_slow_function.cache_clear()
def test_shared_cache_clear_pattern(self):
"""Test pattern-based cache clearing (Redis feature)."""
@cached(ttl_seconds=30, shared_cache=True)
def pattern_function(category: str, item: int) -> str:
return f"{category}_{item}"
# Clear any existing cache
pattern_function.cache_clear()
# Add various entries
pattern_function("fruit", 1)
pattern_function("fruit", 2)
pattern_function("vegetable", 1)
pattern_function("vegetable", 2)
info = pattern_function.cache_info()
assert info["size"] == 4
# Note: Pattern clearing with wildcards requires specific Redis scan
# implementation. The current code clears by pattern but needs
# adjustment for partial matching. For now, test full clear.
pattern_function.cache_clear()
info = pattern_function.cache_info()
assert info["size"] == 0
def test_shared_vs_local_cache_isolation(self):
"""Test that shared and local caches are isolated."""
shared_count = 0
local_count = 0
@cached(ttl_seconds=30, shared_cache=True)
def shared_function(x: int) -> int:
nonlocal shared_count
shared_count += 1
return x * 2
@cached(ttl_seconds=30, shared_cache=False)
def local_function(x: int) -> int:
nonlocal local_count
local_count += 1
return x * 2
# Clear caches
shared_function.cache_clear()
local_function.cache_clear()
# Call both with same args
shared_result = shared_function(5)
local_result = local_function(5)
assert shared_result == local_result == 10
assert shared_count == 1
assert local_count == 1
# Call again - both should use their respective caches
shared_function(5)
local_function(5)
assert shared_count == 1
assert local_count == 1
# Clear only shared cache
shared_function.cache_clear()
# Shared should recompute, local should still use cache
shared_function(5)
local_function(5)
assert shared_count == 2
assert local_count == 1
# Cleanup
shared_function.cache_clear()
local_function.cache_clear()
@pytest.mark.asyncio
async def test_shared_cache_concurrent_different_keys(self):
"""Test that concurrent calls with different keys work correctly."""
call_counts = {}
@cached(ttl_seconds=30, shared_cache=True)
async def multi_key_function(key: str) -> str:
if key not in call_counts:
call_counts[key] = 0
call_counts[key] += 1
await asyncio.sleep(0.05)
return f"result_{key}"
# Clear cache
multi_key_function.cache_clear()
# Launch concurrent tasks with different keys
keys = ["a", "b", "c", "d", "e"]
tasks = []
for key in keys:
# Multiple calls per key
tasks.extend([multi_key_function(key) for _ in range(3)])
results = await asyncio.gather(*tasks)
# Verify results
for i, key in enumerate(keys):
expected = f"result_{key}"
# Each key appears 3 times in results
key_results = results[i * 3 : (i + 1) * 3]
assert all(r == expected for r in key_results)
# Each key should only be computed once
for key in keys:
assert call_counts[key] == 1
# Cleanup
multi_key_function.cache_clear()
def test_shared_cache_performance_comparison(self):
"""Compare performance of shared vs local cache."""
import statistics
shared_times = []
local_times = []
@cached(ttl_seconds=30, shared_cache=True)
def shared_perf_function(x: int) -> int:
time.sleep(0.01) # Simulate work
return x * 2
@cached(ttl_seconds=30, shared_cache=False)
def local_perf_function(x: int) -> int:
time.sleep(0.01) # Simulate work
return x * 2
# Clear caches
shared_perf_function.cache_clear()
local_perf_function.cache_clear()
# Warm up both caches
for i in range(5):
shared_perf_function(i)
local_perf_function(i)
# Measure cache hit times
for i in range(5):
# Shared cache hit
start = time.time()
shared_perf_function(i)
shared_times.append(time.time() - start)
# Local cache hit
start = time.time()
local_perf_function(i)
local_times.append(time.time() - start)
# Local cache should be faster (no Redis round-trip)
avg_shared = statistics.mean(shared_times)
avg_local = statistics.mean(local_times)
print(f"Avg shared cache hit time: {avg_shared:.6f}s")
print(f"Avg local cache hit time: {avg_local:.6f}s")
# Local should be significantly faster for cache hits
# Redis adds network latency even for cache hits
assert avg_local < avg_shared
# Cleanup
shared_perf_function.cache_clear()
local_perf_function.cache_clear()

View File

@@ -16,5 +16,4 @@ load-tests/*_RESULTS.md
load-tests/*_REPORT.md
load-tests/results/
load-tests/*.json
load-tests/*.log
load-tests/node_modules/*
load-tests/*.log

View File

@@ -47,7 +47,6 @@ RUN poetry install --no-ansi --no-root
# Generate Prisma client
COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
RUN poetry run prisma generate
FROM debian:13-slim AS server_dependencies
@@ -93,7 +92,6 @@ FROM server_dependencies AS migrate
# Migration stage only needs schema and migrations - much lighter than full backend
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
COPY autogpt_platform/backend/backend/data/partial_types.py /app/autogpt_platform/backend/backend/data/partial_types.py
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations
FROM server_dependencies AS server

View File

@@ -5,7 +5,7 @@ import re
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar
from backend.util.cache import cached
from autogpt_libs.utils.cache import cached
logger = logging.getLogger(__name__)
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
T = TypeVar("T")
@cached(ttl_seconds=3600)
@cached()
def load_all_blocks() -> dict[str, type["Block"]]:
from backend.data.block import Block
from backend.util.settings import Config

View File

@@ -1,214 +0,0 @@
from typing import Any
from backend.blocks.llm import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
AIBlockBase,
AICredentials,
AICredentialsField,
LlmModel,
LLMResponse,
llm_call,
)
from backend.data.block import BlockCategory, BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
class AIConditionBlock(AIBlockBase):
"""
An AI-powered condition block that uses natural language to evaluate conditions.
This block allows users to define conditions in plain English (e.g., "the input is an email address",
"the input is a city in the USA") and uses AI to determine if the input satisfies the condition.
It provides the same yes/no data pass-through functionality as the standard ConditionBlock.
"""
class Input(BlockSchema):
input_value: Any = SchemaField(
description="The input value to evaluate with the AI condition",
placeholder="Enter the value to be evaluated (text, number, or any data)",
)
condition: str = SchemaField(
description="A plaintext English description of the condition to evaluate",
placeholder="E.g., 'the input is the body of an email', 'the input is a City in the USA', 'the input is an error or a refusal'",
)
yes_value: Any = SchemaField(
description="(Optional) Value to output if the condition is true. If not provided, input_value will be used.",
placeholder="Leave empty to use input_value, or enter a specific value",
default=None,
)
no_value: Any = SchemaField(
description="(Optional) Value to output if the condition is false. If not provided, input_value will be used.",
placeholder="Leave empty to use input_value, or enter a specific value",
default=None,
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
description="The language model to use for evaluating the condition.",
advanced=False,
)
credentials: AICredentials = AICredentialsField()
class Output(BlockSchema):
result: bool = SchemaField(
description="The result of the AI condition evaluation (True or False)"
)
yes_output: Any = SchemaField(
description="The output value if the condition is true"
)
no_output: Any = SchemaField(
description="The output value if the condition is false"
)
error: str = SchemaField(
description="Error message if the AI evaluation is uncertain or fails"
)
def __init__(self):
super().__init__(
id="553ec5b8-6c45-4299-8d75-b394d05f72ff",
input_schema=AIConditionBlock.Input,
output_schema=AIConditionBlock.Output,
description="Uses AI to evaluate natural language conditions and provide conditional outputs",
categories={BlockCategory.AI, BlockCategory.LOGIC},
test_input={
"input_value": "john@example.com",
"condition": "the input is an email address",
"yes_value": "Valid email",
"no_value": "Not an email",
"model": LlmModel.GPT4O,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("result", True),
("yes_output", "Valid email"),
],
test_mock={
"llm_call": lambda *args, **kwargs: LLMResponse(
raw_response="",
prompt=[],
response="true",
tool_calls=None,
prompt_tokens=50,
completion_tokens=10,
reasoning=None,
)
},
)
async def llm_call(
self,
credentials: APIKeyCredentials,
llm_model: LlmModel,
prompt: list,
max_tokens: int,
) -> LLMResponse:
"""Wrapper method for llm_call to enable mocking in tests."""
return await llm_call(
credentials=credentials,
llm_model=llm_model,
prompt=prompt,
force_json_output=False,
max_tokens=max_tokens,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
"""
Evaluate the AI condition and return appropriate outputs.
"""
# Prepare the yes and no values, using input_value as default
yes_value = (
input_data.yes_value
if input_data.yes_value is not None
else input_data.input_value
)
no_value = (
input_data.no_value
if input_data.no_value is not None
else input_data.input_value
)
# Convert input_value to string for AI evaluation
input_str = str(input_data.input_value)
# Create the prompt for AI evaluation
prompt = [
{
"role": "system",
"content": (
"You are an AI assistant that evaluates conditions based on input data. "
"You must respond with only 'true' or 'false' (lowercase) to indicate whether "
"the given condition is met by the input value. Be accurate and consider the "
"context and meaning of both the input and the condition."
),
},
{
"role": "user",
"content": (
f"Input value: {input_str}\n"
f"Condition to evaluate: {input_data.condition}\n\n"
f"Does the input value satisfy the condition? Respond with only 'true' or 'false'."
),
},
]
# Call the LLM
try:
response = await self.llm_call(
credentials=credentials,
llm_model=input_data.model,
prompt=prompt,
max_tokens=10, # We only expect a true/false response
)
# Extract the boolean result from the response
response_text = response.response.strip().lower()
if response_text == "true":
result = True
elif response_text == "false":
result = False
else:
# If the response is not clear, try to interpret it using word boundaries
import re
# Use word boundaries to avoid false positives like 'untrue' or '10'
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", response_text))
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
result = True
elif tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
result = False
else:
# Unclear or conflicting response - default to False and yield error
result = False
yield "error", f"Unclear AI response: '{response.response}'"
# Update internal stats
self.merge_stats(
NodeExecutionStats(
input_token_count=response.prompt_tokens,
output_token_count=response.completion_tokens,
)
)
self.prompt = response.prompt
except Exception as e:
# In case of any error, default to False to be safe
result = False
# Log the error but don't fail the block execution
import logging
logger = logging.getLogger(__name__)
logger.error(f"AI condition evaluation failed: {str(e)}")
yield "error", f"AI evaluation failed: {str(e)}"
# Yield results
yield "result", result
if result:
yield "yes_output", yes_value
else:
yield "no_output", no_value

View File

@@ -1,10 +1,8 @@
from enum import Enum
from typing import Any, Literal, Optional
from typing import Literal
from e2b_code_interpreter import AsyncSandbox
from e2b_code_interpreter import Result as E2BExecutionResult
from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
from pydantic import BaseModel, JsonValue, SecretStr
from pydantic import SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
@@ -38,135 +36,14 @@ class ProgrammingLanguage(Enum):
JAVA = "java"
class MainCodeExecutionResult(BaseModel):
"""
*Pydantic model mirroring `e2b_code_interpreter.Result`*
Represents the data to be displayed as a result of executing a cell in a Jupyter notebook.
The result is similar to the structure returned by ipython kernel: https://ipython.readthedocs.io/en/stable/development/execution.html#execution-semantics
The result can contain multiple types of data, such as text, images, plots, etc. Each type of data is represented
as a string, and the result can contain multiple types of data. The display calls don't have to have text representation,
for the actual result the representation is always present for the result, the other representations are always optional.
""" # noqa
class Chart(BaseModel, E2BExecutionResultChart):
pass
text: Optional[str] = None
html: Optional[str] = None
markdown: Optional[str] = None
svg: Optional[str] = None
png: Optional[str] = None
jpeg: Optional[str] = None
pdf: Optional[str] = None
latex: Optional[str] = None
json: Optional[JsonValue] = None # type: ignore (reportIncompatibleMethodOverride)
javascript: Optional[str] = None
data: Optional[dict] = None
chart: Optional[Chart] = None
extra: Optional[dict] = None
"""Extra data that can be included. Not part of the standard types."""
class CodeExecutionResult(MainCodeExecutionResult):
__doc__ = MainCodeExecutionResult.__doc__
is_main_result: bool = False
"""Whether this data is the main result of the cell. Data can be produced by display calls of which can be multiple in a cell.""" # noqa
class BaseE2BExecutorMixin:
"""Shared implementation methods for E2B executor blocks."""
async def execute_code(
self,
api_key: str,
code: str,
language: ProgrammingLanguage,
template_id: str = "",
setup_commands: Optional[list[str]] = None,
timeout: Optional[int] = None,
sandbox_id: Optional[str] = None,
dispose_sandbox: bool = False,
):
"""
Unified code execution method that handles all three use cases:
1. Create new sandbox and execute (ExecuteCodeBlock)
2. Create new sandbox, execute, and return sandbox_id (InstantiateCodeSandboxBlock)
3. Connect to existing sandbox and execute (ExecuteCodeStepBlock)
""" # noqa
sandbox = None
try:
if sandbox_id:
# Connect to existing sandbox (ExecuteCodeStepBlock case)
sandbox = await AsyncSandbox.connect(
sandbox_id=sandbox_id, api_key=api_key
)
else:
# Create new sandbox (ExecuteCodeBlock/InstantiateCodeSandboxBlock case)
sandbox = await AsyncSandbox.create(
api_key=api_key, template=template_id, timeout=timeout
)
if setup_commands:
for cmd in setup_commands:
await sandbox.commands.run(cmd)
# Execute the code
execution = await sandbox.run_code(
code,
language=language.value,
on_error=lambda e: sandbox.kill(), # Kill the sandbox on error
)
if execution.error:
raise Exception(execution.error)
results = execution.results
text_output = execution.text
stdout_logs = "".join(execution.logs.stdout)
stderr_logs = "".join(execution.logs.stderr)
return results, text_output, stdout_logs, stderr_logs, sandbox.sandbox_id
finally:
# Dispose of sandbox if requested to reduce usage costs
if dispose_sandbox and sandbox:
await sandbox.kill()
def process_execution_results(
self, results: list[E2BExecutionResult]
) -> tuple[dict[str, Any] | None, list[dict[str, Any]]]:
"""Process and filter execution results."""
# Filter out empty formats and convert to dicts
processed_results = [
{
f: value
for f in [*r.formats(), "extra", "is_main_result"]
if (value := getattr(r, f, None)) is not None
}
for r in results
]
if main_result := next(
(r for r in processed_results if r.get("is_main_result")), None
):
# Make main_result a copy we can modify & remove is_main_result
(main_result := {**main_result}).pop("is_main_result")
return main_result, processed_results
class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
class CodeExecutionBlock(Block):
# TODO : Add support to upload and download files
# NOTE: Currently, you can only customize the CPU and Memory
# by creating a pre customized sandbox template
# Currently, You can customized the CPU and Memory, only by creating a pre customized sandbox template
class Input(BlockSchema):
credentials: CredentialsMetaInput[
Literal[ProviderName.E2B], Literal["api_key"]
] = CredentialsField(
description=(
"Enter your API key for the E2B platform. "
"You can get it in here - https://e2b.dev/docs"
),
description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
)
# Todo : Option to run commond in background
@@ -199,14 +76,6 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
description="Execution timeout in seconds", default=300
)
dispose_sandbox: bool = SchemaField(
description=(
"Whether to dispose of the sandbox immediately after execution. "
"If disabled, the sandbox will run until its timeout expires."
),
default=True,
)
template_id: str = SchemaField(
description=(
"You can use an E2B sandbox template by entering its ID here. "
@@ -218,16 +87,7 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
)
class Output(BlockSchema):
main_result: MainCodeExecutionResult = SchemaField(
title="Main Result", description="The main result from the code execution"
)
results: list[CodeExecutionResult] = SchemaField(
description="List of results from the code execution"
)
response: str = SchemaField(
title="Main Text Output",
description="Text output (if any) of the main execution result",
)
response: str = SchemaField(description="Response from code execution")
stdout_logs: str = SchemaField(
description="Standard output logs from execution"
)
@@ -237,10 +97,10 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
def __init__(self):
super().__init__(
id="0b02b072-abe7-11ef-8372-fb5d162dd712",
description="Executes code in a sandbox environment with internet access.",
description="Executes code in an isolated sandbox environment with internet access.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=ExecuteCodeBlock.Input,
output_schema=ExecuteCodeBlock.Output,
input_schema=CodeExecutionBlock.Input,
output_schema=CodeExecutionBlock.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
@@ -251,59 +111,91 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
"template_id": "",
},
test_output=[
("results", []),
("response", "Hello World"),
("stdout_logs", "Hello World\n"),
],
test_mock={
"execute_code": lambda api_key, code, language, template_id, setup_commands, timeout, dispose_sandbox: ( # noqa
[], # results
"Hello World", # text_output
"Hello World\n", # stdout_logs
"", # stderr_logs
"sandbox_id", # sandbox_id
"execute_code": lambda code, language, setup_commands, timeout, api_key, template_id: (
"Hello World",
"Hello World\n",
"",
),
},
)
async def execute_code(
self,
code: str,
language: ProgrammingLanguage,
setup_commands: list[str],
timeout: int,
api_key: str,
template_id: str,
):
try:
sandbox = None
if template_id:
sandbox = await AsyncSandbox.create(
template=template_id, api_key=api_key, timeout=timeout
)
else:
sandbox = await AsyncSandbox.create(api_key=api_key, timeout=timeout)
if not sandbox:
raise Exception("Sandbox not created")
# Running setup commands
for cmd in setup_commands:
await sandbox.commands.run(cmd)
# Executing the code
execution = await sandbox.run_code(
code,
language=language.value,
on_error=lambda e: sandbox.kill(), # Kill the sandbox if there is an error
)
if execution.error:
raise Exception(execution.error)
response = execution.text
stdout_logs = "".join(execution.logs.stdout)
stderr_logs = "".join(execution.logs.stderr)
return response, stdout_logs, stderr_logs
except Exception as e:
raise e
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
results, text_output, stdout, stderr, _ = await self.execute_code(
api_key=credentials.api_key.get_secret_value(),
code=input_data.code,
language=input_data.language,
template_id=input_data.template_id,
setup_commands=input_data.setup_commands,
timeout=input_data.timeout,
dispose_sandbox=input_data.dispose_sandbox,
response, stdout_logs, stderr_logs = await self.execute_code(
input_data.code,
input_data.language,
input_data.setup_commands,
input_data.timeout,
credentials.api_key.get_secret_value(),
input_data.template_id,
)
# Determine result object shape & filter out empty formats
main_result, results = self.process_execution_results(results)
if main_result:
yield "main_result", main_result
yield "results", results
if text_output:
yield "response", text_output
if stdout:
yield "stdout_logs", stdout
if stderr:
yield "stderr_logs", stderr
if response:
yield "response", response
if stdout_logs:
yield "stdout_logs", stdout_logs
if stderr_logs:
yield "stderr_logs", stderr_logs
except Exception as e:
yield "error", str(e)
class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
class InstantiationBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[
Literal[ProviderName.E2B], Literal["api_key"]
] = CredentialsField(
description=(
"Enter your API key for the E2B platform. "
"You can get it in here - https://e2b.dev/docs"
)
description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
)
# Todo : Option to run commond in background
@@ -348,10 +240,7 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
class Output(BlockSchema):
sandbox_id: str = SchemaField(description="ID of the sandbox instance")
response: str = SchemaField(
title="Text Result",
description="Text result (if any) of the setup code execution",
)
response: str = SchemaField(description="Response from code execution")
stdout_logs: str = SchemaField(
description="Standard output logs from execution"
)
@@ -361,13 +250,10 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
def __init__(self):
super().__init__(
id="ff0861c9-1726-4aec-9e5b-bf53f3622112",
description=(
"Instantiate a sandbox environment with internet access "
"in which you can execute code with the Execute Code Step block."
),
description="Instantiate an isolated sandbox environment with internet access where to execute code in.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=InstantiateCodeSandboxBlock.Input,
output_schema=InstantiateCodeSandboxBlock.Output,
input_schema=InstantiationBlock.Input,
output_schema=InstantiationBlock.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
@@ -383,12 +269,11 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
("stdout_logs", "Hello World\n"),
],
test_mock={
"execute_code": lambda api_key, code, language, template_id, setup_commands, timeout: ( # noqa
[], # results
"Hello World", # text_output
"Hello World\n", # stdout_logs
"", # stderr_logs
"sandbox_id", # sandbox_id
"execute_code": lambda setup_code, language, setup_commands, timeout, api_key, template_id: (
"sandbox_id",
"Hello World",
"Hello World\n",
"",
),
},
)
@@ -397,38 +282,78 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
_, text_output, stdout, stderr, sandbox_id = await self.execute_code(
api_key=credentials.api_key.get_secret_value(),
code=input_data.setup_code,
language=input_data.language,
template_id=input_data.template_id,
setup_commands=input_data.setup_commands,
timeout=input_data.timeout,
sandbox_id, response, stdout_logs, stderr_logs = await self.execute_code(
input_data.setup_code,
input_data.language,
input_data.setup_commands,
input_data.timeout,
credentials.api_key.get_secret_value(),
input_data.template_id,
)
if sandbox_id:
yield "sandbox_id", sandbox_id
else:
yield "error", "Sandbox ID not found"
if text_output:
yield "response", text_output
if stdout:
yield "stdout_logs", stdout
if stderr:
yield "stderr_logs", stderr
if response:
yield "response", response
if stdout_logs:
yield "stdout_logs", stdout_logs
if stderr_logs:
yield "stderr_logs", stderr_logs
except Exception as e:
yield "error", str(e)
async def execute_code(
self,
code: str,
language: ProgrammingLanguage,
setup_commands: list[str],
timeout: int,
api_key: str,
template_id: str,
):
try:
sandbox = None
if template_id:
sandbox = await AsyncSandbox.create(
template=template_id, api_key=api_key, timeout=timeout
)
else:
sandbox = await AsyncSandbox.create(api_key=api_key, timeout=timeout)
class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
if not sandbox:
raise Exception("Sandbox not created")
# Running setup commands
for cmd in setup_commands:
await sandbox.commands.run(cmd)
# Executing the code
execution = await sandbox.run_code(
code,
language=language.value,
on_error=lambda e: sandbox.kill(), # Kill the sandbox if there is an error
)
if execution.error:
raise Exception(execution.error)
response = execution.text
stdout_logs = "".join(execution.logs.stdout)
stderr_logs = "".join(execution.logs.stderr)
return sandbox.sandbox_id, response, stdout_logs, stderr_logs
except Exception as e:
raise e
class StepExecutionBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[
Literal[ProviderName.E2B], Literal["api_key"]
] = CredentialsField(
description=(
"Enter your API key for the E2B platform. "
"You can get it in here - https://e2b.dev/docs"
),
description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
)
sandbox_id: str = SchemaField(
@@ -449,22 +374,8 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
advanced=False,
)
dispose_sandbox: bool = SchemaField(
description="Whether to dispose of the sandbox after executing this code.",
default=False,
)
class Output(BlockSchema):
main_result: MainCodeExecutionResult = SchemaField(
title="Main Result", description="The main result from the code execution"
)
results: list[CodeExecutionResult] = SchemaField(
description="List of results from the code execution"
)
response: str = SchemaField(
title="Main Text Output",
description="Text output (if any) of the main execution result",
)
response: str = SchemaField(description="Response from code execution")
stdout_logs: str = SchemaField(
description="Standard output logs from execution"
)
@@ -474,10 +385,10 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
def __init__(self):
super().__init__(
id="82b59b8e-ea10-4d57-9161-8b169b0adba6",
description="Execute code in a previously instantiated sandbox.",
description="Execute code in a previously instantiated sandbox environment.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=ExecuteCodeStepBlock.Input,
output_schema=ExecuteCodeStepBlock.Output,
input_schema=StepExecutionBlock.Input,
output_schema=StepExecutionBlock.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
@@ -486,43 +397,61 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
"language": ProgrammingLanguage.PYTHON.value,
},
test_output=[
("results", []),
("response", "Hello World"),
("stdout_logs", "Hello World\n"),
],
test_mock={
"execute_code": lambda api_key, code, language, sandbox_id, dispose_sandbox: ( # noqa
[], # results
"Hello World", # text_output
"Hello World\n", # stdout_logs
"", # stderr_logs
sandbox_id, # sandbox_id
"execute_step_code": lambda sandbox_id, step_code, language, api_key: (
"Hello World",
"Hello World\n",
"",
),
},
)
async def execute_step_code(
self,
sandbox_id: str,
code: str,
language: ProgrammingLanguage,
api_key: str,
):
try:
sandbox = await AsyncSandbox.connect(sandbox_id=sandbox_id, api_key=api_key)
if not sandbox:
raise Exception("Sandbox not found")
# Executing the code
execution = await sandbox.run_code(code, language=language.value)
if execution.error:
raise Exception(execution.error)
response = execution.text
stdout_logs = "".join(execution.logs.stdout)
stderr_logs = "".join(execution.logs.stderr)
return response, stdout_logs, stderr_logs
except Exception as e:
raise e
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
results, text_output, stdout, stderr, _ = await self.execute_code(
api_key=credentials.api_key.get_secret_value(),
code=input_data.step_code,
language=input_data.language,
sandbox_id=input_data.sandbox_id,
dispose_sandbox=input_data.dispose_sandbox,
response, stdout_logs, stderr_logs = await self.execute_step_code(
input_data.sandbox_id,
input_data.step_code,
input_data.language,
credentials.api_key.get_secret_value(),
)
# Determine result object shape & filter out empty formats
main_result, results = self.process_execution_results(results)
if main_result:
yield "main_result", main_result
yield "results", results
if text_output:
yield "response", text_output
if stdout:
yield "stdout_logs", stdout
if stderr:
yield "stderr_logs", stderr
if response:
yield "response", response
if stdout_logs:
yield "stdout_logs", stdout_logs
if stderr_logs:
yield "stderr_logs", stderr_logs
except Exception as e:
yield "error", str(e)

View File

@@ -90,7 +90,7 @@ class CodeExtractionBlock(Block):
for aliases in language_aliases.values()
for alias in aliases
)
+ r")[ \t]*\n[\s\S]*?```"
+ r")\s+[\s\S]*?```"
)
remaining_text = re.sub(pattern, "", input_data.text).strip()
@@ -103,9 +103,7 @@ class CodeExtractionBlock(Block):
# Escape special regex characters in the language string
language = re.escape(language)
# Extract all code blocks enclosed in ```language``` blocks
pattern = re.compile(
rf"```{language}[ \t]*\n(.*?)\n```", re.DOTALL | re.IGNORECASE
)
pattern = re.compile(rf"```{language}\s+(.*?)```", re.DOTALL | re.IGNORECASE)
matches = pattern.finditer(text)
# Combine all code blocks for this language with newlines between them
code_blocks = [match.group(1).strip() for match in matches]

View File

@@ -66,7 +66,6 @@ class AddToDictionaryBlock(Block):
dictionary: dict[Any, Any] = SchemaField(
default_factory=dict,
description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
advanced=False,
)
key: str = SchemaField(
default="",

View File

@@ -113,7 +113,6 @@ class DataForSeoClient:
include_serp_info: bool = False,
include_clickstream_data: bool = False,
limit: int = 100,
depth: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
Get related keywords from DataForSEO Labs.
@@ -126,7 +125,6 @@ class DataForSeoClient:
include_serp_info: Include SERP data
include_clickstream_data: Include clickstream metrics
limit: Maximum number of results (up to 3000)
depth: Keyword search depth (0-4), controls number of returned keywords
Returns:
API response with related keywords
@@ -150,8 +148,6 @@ class DataForSeoClient:
task_data["include_clickstream_data"] = include_clickstream_data
if limit is not None:
task_data["limit"] = limit
if depth is not None:
task_data["depth"] = depth
payload = [task_data]

View File

@@ -90,7 +90,6 @@ class DataForSeoKeywordSuggestionsBlock(Block):
seed_keyword: str = SchemaField(
description="The seed keyword used for the query"
)
error: str = SchemaField(description="Error message if the API call failed")
def __init__(self):
super().__init__(
@@ -162,52 +161,43 @@ class DataForSeoKeywordSuggestionsBlock(Block):
**kwargs,
) -> BlockOutput:
"""Execute the keyword suggestions query."""
try:
client = DataForSeoClient(credentials)
client = DataForSeoClient(credentials)
results = await self._fetch_keyword_suggestions(client, input_data)
results = await self._fetch_keyword_suggestions(client, input_data)
# Process and format the results
suggestions = []
if results and len(results) > 0:
# results is a list, get the first element
first_result = results[0] if isinstance(results, list) else results
items = (
first_result.get("items", [])
if isinstance(first_result, dict)
else []
# Process and format the results
suggestions = []
if results and len(results) > 0:
# results is a list, get the first element
first_result = results[0] if isinstance(results, list) else results
items = (
first_result.get("items", []) if isinstance(first_result, dict) else []
)
for item in items:
# Create the KeywordSuggestion object
suggestion = KeywordSuggestion(
keyword=item.get("keyword", ""),
search_volume=item.get("keyword_info", {}).get("search_volume"),
competition=item.get("keyword_info", {}).get("competition"),
cpc=item.get("keyword_info", {}).get("cpc"),
keyword_difficulty=item.get("keyword_properties", {}).get(
"keyword_difficulty"
),
serp_info=(
item.get("serp_info") if input_data.include_serp_info else None
),
clickstream_data=(
item.get("clickstream_keyword_info")
if input_data.include_clickstream_data
else None
),
)
if items is None:
items = []
for item in items:
# Create the KeywordSuggestion object
suggestion = KeywordSuggestion(
keyword=item.get("keyword", ""),
search_volume=item.get("keyword_info", {}).get("search_volume"),
competition=item.get("keyword_info", {}).get("competition"),
cpc=item.get("keyword_info", {}).get("cpc"),
keyword_difficulty=item.get("keyword_properties", {}).get(
"keyword_difficulty"
),
serp_info=(
item.get("serp_info")
if input_data.include_serp_info
else None
),
clickstream_data=(
item.get("clickstream_keyword_info")
if input_data.include_clickstream_data
else None
),
)
yield "suggestion", suggestion
suggestions.append(suggestion)
yield "suggestion", suggestion
suggestions.append(suggestion)
yield "suggestions", suggestions
yield "total_count", len(suggestions)
yield "seed_keyword", input_data.keyword
except Exception as e:
yield "error", f"Failed to fetch keyword suggestions: {str(e)}"
yield "suggestions", suggestions
yield "total_count", len(suggestions)
yield "seed_keyword", input_data.keyword
class KeywordSuggestionExtractorBlock(Block):

View File

@@ -78,12 +78,6 @@ class DataForSeoRelatedKeywordsBlock(Block):
ge=1,
le=3000,
)
depth: int = SchemaField(
description="Keyword search depth (0-4). Controls the number of returned keywords: 0=1 keyword, 1=~8 keywords, 2=~72 keywords, 3=~584 keywords, 4=~4680 keywords",
default=1,
ge=0,
le=4,
)
class Output(BlockSchema):
related_keywords: List[RelatedKeyword] = SchemaField(
@@ -98,7 +92,6 @@ class DataForSeoRelatedKeywordsBlock(Block):
seed_keyword: str = SchemaField(
description="The seed keyword used for the query"
)
error: str = SchemaField(description="Error message if the API call failed")
def __init__(self):
super().__init__(
@@ -161,7 +154,6 @@ class DataForSeoRelatedKeywordsBlock(Block):
include_serp_info=input_data.include_serp_info,
include_clickstream_data=input_data.include_clickstream_data,
limit=input_data.limit,
depth=input_data.depth,
)
async def run(
@@ -172,60 +164,50 @@ class DataForSeoRelatedKeywordsBlock(Block):
**kwargs,
) -> BlockOutput:
"""Execute the related keywords query."""
try:
client = DataForSeoClient(credentials)
client = DataForSeoClient(credentials)
results = await self._fetch_related_keywords(client, input_data)
results = await self._fetch_related_keywords(client, input_data)
# Process and format the results
related_keywords = []
if results and len(results) > 0:
# results is a list, get the first element
first_result = results[0] if isinstance(results, list) else results
items = (
first_result.get("items", [])
if isinstance(first_result, dict)
else []
# Process and format the results
related_keywords = []
if results and len(results) > 0:
# results is a list, get the first element
first_result = results[0] if isinstance(results, list) else results
items = (
first_result.get("items", []) if isinstance(first_result, dict) else []
)
for item in items:
# Extract keyword_data from the item
keyword_data = item.get("keyword_data", {})
# Create the RelatedKeyword object
keyword = RelatedKeyword(
keyword=keyword_data.get("keyword", ""),
search_volume=keyword_data.get("keyword_info", {}).get(
"search_volume"
),
competition=keyword_data.get("keyword_info", {}).get("competition"),
cpc=keyword_data.get("keyword_info", {}).get("cpc"),
keyword_difficulty=keyword_data.get("keyword_properties", {}).get(
"keyword_difficulty"
),
serp_info=(
keyword_data.get("serp_info")
if input_data.include_serp_info
else None
),
clickstream_data=(
keyword_data.get("clickstream_keyword_info")
if input_data.include_clickstream_data
else None
),
)
# Ensure items is never None
if items is None:
items = []
for item in items:
# Extract keyword_data from the item
keyword_data = item.get("keyword_data", {})
yield "related_keyword", keyword
related_keywords.append(keyword)
# Create the RelatedKeyword object
keyword = RelatedKeyword(
keyword=keyword_data.get("keyword", ""),
search_volume=keyword_data.get("keyword_info", {}).get(
"search_volume"
),
competition=keyword_data.get("keyword_info", {}).get(
"competition"
),
cpc=keyword_data.get("keyword_info", {}).get("cpc"),
keyword_difficulty=keyword_data.get(
"keyword_properties", {}
).get("keyword_difficulty"),
serp_info=(
keyword_data.get("serp_info")
if input_data.include_serp_info
else None
),
clickstream_data=(
keyword_data.get("clickstream_keyword_info")
if input_data.include_clickstream_data
else None
),
)
yield "related_keyword", keyword
related_keywords.append(keyword)
yield "related_keywords", related_keywords
yield "total_count", len(related_keywords)
yield "seed_keyword", input_data.keyword
except Exception as e:
yield "error", f"Failed to fetch related keywords: {str(e)}"
yield "related_keywords", related_keywords
yield "total_count", len(related_keywords)
yield "seed_keyword", input_data.keyword
class RelatedKeywordExtractorBlock(Block):

View File

@@ -4,13 +4,13 @@ import mimetypes
from pathlib import Path
from typing import Any
import aiohttp
import discord
from pydantic import SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, SchemaField
from backend.util.file import store_media_file
from backend.util.request import Requests
from backend.util.type import MediaFileType
from ._auth import (
@@ -114,9 +114,10 @@ class ReadDiscordMessagesBlock(Block):
if message.attachments:
attachment = message.attachments[0] # Process the first attachment
if attachment.filename.endswith((".txt", ".py")):
response = await Requests().get(attachment.url)
file_content = response.text()
self.output_data += f"\n\nFile from user: {attachment.filename}\nContent: {file_content}"
async with aiohttp.ClientSession() as session:
async with session.get(attachment.url) as response:
file_content = response.text()
self.output_data += f"\n\nFile from user: {attachment.filename}\nContent: {file_content}"
await client.close()
@@ -170,11 +171,11 @@ class SendDiscordMessageBlock(Block):
description="The content of the message to send"
)
channel_name: str = SchemaField(
description="Channel ID or channel name to send the message to"
description="The name of the channel the message will be sent to"
)
server_name: str = SchemaField(
description="Server name (only needed if using channel name)",
advanced=True,
description="The name of the server where the channel is located",
advanced=True, # Optional field for server name
default="",
)
@@ -230,49 +231,25 @@ class SendDiscordMessageBlock(Block):
@client.event
async def on_ready():
print(f"Logged in as {client.user}")
channel = None
for guild in client.guilds:
if server_name and guild.name != server_name:
continue
for channel in guild.text_channels:
if channel.name == channel_name:
# Split message into chunks if it exceeds 2000 characters
chunks = self.chunk_message(message_content)
last_message = None
for chunk in chunks:
last_message = await channel.send(chunk)
result["status"] = "Message sent"
result["message_id"] = (
str(last_message.id) if last_message else ""
)
result["channel_id"] = str(channel.id)
await client.close()
return
# Try to parse as channel ID first
try:
channel_id = int(channel_name)
channel = client.get_channel(channel_id)
except ValueError:
# Not a valid ID, will try name lookup
pass
# If not found by ID (or not an ID), try name lookup
if not channel:
for guild in client.guilds:
if server_name and guild.name != server_name:
continue
for ch in guild.text_channels:
if ch.name == channel_name:
channel = ch
break
if channel:
break
if not channel:
result["status"] = f"Channel not found: {channel_name}"
await client.close()
return
# Type check - ensure it's a text channel that can send messages
if not hasattr(channel, "send"):
result["status"] = (
f"Channel {channel_name} cannot receive messages (not a text channel)"
)
await client.close()
return
# Split message into chunks if it exceeds 2000 characters
chunks = self.chunk_message(message_content)
last_message = None
for chunk in chunks:
last_message = await channel.send(chunk) # type: ignore
result["status"] = "Message sent"
result["message_id"] = str(last_message.id) if last_message else ""
result["channel_id"] = str(channel.id)
result["status"] = "Channel not found"
await client.close()
await client.start(token)
@@ -698,15 +675,16 @@ class SendDiscordFileBlock(Block):
elif file.startswith(("http://", "https://")):
# URL - download the file
response = await Requests().get(file)
file_bytes = response.content
async with aiohttp.ClientSession() as session:
async with session.get(file) as response:
file_bytes = await response.read()
# Try to get filename from URL if not provided
if not filename:
from urllib.parse import urlparse
# Try to get filename from URL if not provided
if not filename:
from urllib.parse import urlparse
path = urlparse(file).path
detected_filename = Path(path).name or "download"
path = urlparse(file).path
detected_filename = Path(path).name or "download"
else:
# Local file path - read from stored media file
# This would be a path from a previous block's output

View File

@@ -1,12 +0,0 @@
from enum import Enum
class ScrapeFormat(Enum):
MARKDOWN = "markdown"
HTML = "html"
RAW_HTML = "rawHtml"
LINKS = "links"
SCREENSHOT = "screenshot"
SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
JSON = "json"
CHANGE_TRACKING = "changeTracking"

View File

@@ -1,28 +0,0 @@
"""Utility functions for converting between our ScrapeFormat enum and firecrawl FormatOption types."""
from typing import List
from firecrawl.v2.types import FormatOption, ScreenshotFormat
from backend.blocks.firecrawl._api import ScrapeFormat
def convert_to_format_options(
formats: List[ScrapeFormat],
) -> List[FormatOption]:
"""Convert our ScrapeFormat enum values to firecrawl FormatOption types.
Handles special cases like screenshot@fullPage which needs to be converted
to a ScreenshotFormat object.
"""
result: List[FormatOption] = []
for format_enum in formats:
if format_enum.value == "screenshot@fullPage":
# Special case: convert to ScreenshotFormat with full_page=True
result.append(ScreenshotFormat(type="screenshot", full_page=True))
else:
# Regular string literals
result.append(format_enum.value)
return result

View File

@@ -1,9 +1,8 @@
from enum import Enum
from typing import Any
from firecrawl import FirecrawlApp
from firecrawl.v2.types import ScrapeOptions
from firecrawl import FirecrawlApp, ScrapeOptions
from backend.blocks.firecrawl._api import ScrapeFormat
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -15,10 +14,21 @@ from backend.sdk import (
)
from ._config import firecrawl
from ._format_utils import convert_to_format_options
class ScrapeFormat(Enum):
MARKDOWN = "markdown"
HTML = "html"
RAW_HTML = "rawHtml"
LINKS = "links"
SCREENSHOT = "screenshot"
SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
JSON = "json"
CHANGE_TRACKING = "changeTracking"
class FirecrawlCrawlBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = firecrawl.credentials_field()
url: str = SchemaField(description="The URL to crawl")
@@ -68,17 +78,18 @@ class FirecrawlCrawlBlock(Block):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
# Sync call
crawl_result = app.crawl(
crawl_result = app.crawl_url(
input_data.url,
limit=input_data.limit,
scrape_options=ScrapeOptions(
formats=convert_to_format_options(input_data.formats),
only_main_content=input_data.only_main_content,
max_age=input_data.max_age,
wait_for=input_data.wait_for,
formats=[format.value for format in input_data.formats],
onlyMainContent=input_data.only_main_content,
maxAge=input_data.max_age,
waitFor=input_data.wait_for,
),
)
yield "data", crawl_result.data
@@ -90,7 +101,7 @@ class FirecrawlCrawlBlock(Block):
elif f == ScrapeFormat.HTML:
yield "html", data.html
elif f == ScrapeFormat.RAW_HTML:
yield "raw_html", data.raw_html
yield "raw_html", data.rawHtml
elif f == ScrapeFormat.LINKS:
yield "links", data.links
elif f == ScrapeFormat.SCREENSHOT:
@@ -98,6 +109,6 @@ class FirecrawlCrawlBlock(Block):
elif f == ScrapeFormat.SCREENSHOT_FULL_PAGE:
yield "screenshot_full_page", data.screenshot
elif f == ScrapeFormat.CHANGE_TRACKING:
yield "change_tracking", data.change_tracking
yield "change_tracking", data.changeTracking
elif f == ScrapeFormat.JSON:
yield "json", data.json

View File

@@ -20,6 +20,7 @@ from ._config import firecrawl
@cost(BlockCost(2, BlockCostType.RUN))
class FirecrawlExtractBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = firecrawl.credentials_field()
urls: list[str] = SchemaField(
@@ -52,6 +53,7 @@ class FirecrawlExtractBlock(Block):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
extract_result = app.extract(

View File

@@ -1,5 +1,3 @@
from typing import Any
from firecrawl import FirecrawlApp
from backend.sdk import (
@@ -16,16 +14,14 @@ from ._config import firecrawl
class FirecrawlMapWebsiteBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = firecrawl.credentials_field()
url: str = SchemaField(description="The website url to map")
class Output(BlockSchema):
links: list[str] = SchemaField(description="List of URLs found on the website")
results: list[dict[str, Any]] = SchemaField(
description="List of search results with url, title, and description"
)
links: list[str] = SchemaField(description="The links of the website")
def __init__(self):
super().__init__(
@@ -39,22 +35,12 @@ class FirecrawlMapWebsiteBlock(Block):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
# Sync call
map_result = app.map(
map_result = app.map_url(
url=input_data.url,
)
# Convert SearchResult objects to dicts
results_data = [
{
"url": link.url,
"title": link.title,
"description": link.description,
}
for link in map_result.links
]
yield "links", [link.url for link in map_result.links]
yield "results", results_data
yield "links", map_result.links

View File

@@ -1,8 +1,8 @@
from enum import Enum
from typing import Any
from firecrawl import FirecrawlApp
from backend.blocks.firecrawl._api import ScrapeFormat
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -14,10 +14,21 @@ from backend.sdk import (
)
from ._config import firecrawl
from ._format_utils import convert_to_format_options
class ScrapeFormat(Enum):
MARKDOWN = "markdown"
HTML = "html"
RAW_HTML = "rawHtml"
LINKS = "links"
SCREENSHOT = "screenshot"
SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
JSON = "json"
CHANGE_TRACKING = "changeTracking"
class FirecrawlScrapeBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = firecrawl.credentials_field()
url: str = SchemaField(description="The URL to crawl")
@@ -67,11 +78,12 @@ class FirecrawlScrapeBlock(Block):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
scrape_result = app.scrape(
scrape_result = app.scrape_url(
input_data.url,
formats=convert_to_format_options(input_data.formats),
formats=[format.value for format in input_data.formats],
only_main_content=input_data.only_main_content,
max_age=input_data.max_age,
wait_for=input_data.wait_for,
@@ -84,7 +96,7 @@ class FirecrawlScrapeBlock(Block):
elif f == ScrapeFormat.HTML:
yield "html", scrape_result.html
elif f == ScrapeFormat.RAW_HTML:
yield "raw_html", scrape_result.raw_html
yield "raw_html", scrape_result.rawHtml
elif f == ScrapeFormat.LINKS:
yield "links", scrape_result.links
elif f == ScrapeFormat.SCREENSHOT:
@@ -92,6 +104,6 @@ class FirecrawlScrapeBlock(Block):
elif f == ScrapeFormat.SCREENSHOT_FULL_PAGE:
yield "screenshot_full_page", scrape_result.screenshot
elif f == ScrapeFormat.CHANGE_TRACKING:
yield "change_tracking", scrape_result.change_tracking
yield "change_tracking", scrape_result.changeTracking
elif f == ScrapeFormat.JSON:
yield "json", scrape_result.json

View File

@@ -1,9 +1,8 @@
from enum import Enum
from typing import Any
from firecrawl import FirecrawlApp
from firecrawl.v2.types import ScrapeOptions
from firecrawl import FirecrawlApp, ScrapeOptions
from backend.blocks.firecrawl._api import ScrapeFormat
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -15,10 +14,21 @@ from backend.sdk import (
)
from ._config import firecrawl
from ._format_utils import convert_to_format_options
class ScrapeFormat(Enum):
MARKDOWN = "markdown"
HTML = "html"
RAW_HTML = "rawHtml"
LINKS = "links"
SCREENSHOT = "screenshot"
SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
JSON = "json"
CHANGE_TRACKING = "changeTracking"
class FirecrawlSearchBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = firecrawl.credentials_field()
query: str = SchemaField(description="The query to search for")
@@ -51,6 +61,7 @@ class FirecrawlSearchBlock(Block):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
# Sync call
@@ -58,12 +69,11 @@ class FirecrawlSearchBlock(Block):
input_data.query,
limit=input_data.limit,
scrape_options=ScrapeOptions(
formats=convert_to_format_options(input_data.formats) or None,
max_age=input_data.max_age,
wait_for=input_data.wait_for,
formats=[format.value for format in input_data.formats],
maxAge=input_data.max_age,
waitFor=input_data.wait_for,
),
)
yield "data", scrape_result
if hasattr(scrape_result, "web") and scrape_result.web:
for site in scrape_result.web:
yield "site", site
for site in scrape_result.data:
yield "site", site

View File

@@ -10,6 +10,7 @@ from backend.util.settings import Config
from backend.util.text import TextFormatter
from backend.util.type import LongTextType, MediaFileType, ShortTextType
formatter = TextFormatter()
config = Config()
@@ -131,11 +132,6 @@ class AgentOutputBlock(Block):
default="",
advanced=True,
)
escape_html: bool = SchemaField(
default=False,
advanced=True,
description="Whether to escape special characters in the inserted values to be HTML-safe. Enable for HTML output, disable for plain text.",
)
advanced: bool = SchemaField(
description="Whether to treat the output as advanced.",
default=False,
@@ -197,7 +193,6 @@ class AgentOutputBlock(Block):
"""
if input_data.format:
try:
formatter = TextFormatter(autoescape=input_data.escape_html)
yield "output", formatter.format_string(
input_data.format, {input_data.name: input_data.value}
)
@@ -554,89 +549,6 @@ class AgentToggleInputBlock(AgentInputBlock):
)
class AgentTableInputBlock(AgentInputBlock):
"""
This block allows users to input data in a table format.
Configure the table columns at build time, then users can input
rows of data at runtime. Each row is output as a dictionary
with column names as keys.
"""
class Input(AgentInputBlock.Input):
value: Optional[list[dict[str, Any]]] = SchemaField(
description="The table data as a list of dictionaries.",
default=None,
advanced=False,
title="Default Value",
)
column_headers: list[str] = SchemaField(
description="Column headers for the table.",
default_factory=lambda: ["Column 1", "Column 2", "Column 3"],
advanced=False,
title="Column Headers",
)
def generate_schema(self):
"""Generate schema for the value field with table format."""
schema = super().generate_schema()
schema["type"] = "array"
schema["format"] = "table"
schema["items"] = {
"type": "object",
"properties": {
header: {"type": "string"}
for header in (
self.column_headers or ["Column 1", "Column 2", "Column 3"]
)
},
}
if self.value is not None:
schema["default"] = self.value
return schema
class Output(AgentInputBlock.Output):
result: list[dict[str, Any]] = SchemaField(
description="The table data as a list of dictionaries with headers as keys."
)
def __init__(self):
super().__init__(
id="5603b273-f41e-4020-af7d-fbc9c6a8d928",
description="Block for table data input with customizable headers.",
disabled=not config.enable_agent_input_subtype_blocks,
input_schema=AgentTableInputBlock.Input,
output_schema=AgentTableInputBlock.Output,
test_input=[
{
"name": "test_table",
"column_headers": ["Name", "Age", "City"],
"value": [
{"Name": "John", "Age": "30", "City": "New York"},
{"Name": "Jane", "Age": "25", "City": "London"},
],
"description": "Example table input",
}
],
test_output=[
(
"result",
[
{"Name": "John", "Age": "30", "City": "New York"},
{"Name": "Jane", "Age": "25", "City": "London"},
],
)
],
)
async def run(self, input_data: Input, *args, **kwargs) -> BlockOutput:
"""
Yields the table data as a list of dictionaries.
"""
# Pass through the value, defaulting to empty list if None
yield "result", input_data.value if input_data.value is not None else []
IO_BLOCK_IDs = [
AgentInputBlock().id,
AgentOutputBlock().id,
@@ -648,5 +560,4 @@ IO_BLOCK_IDs = [
AgentFileInputBlock().id,
AgentDropdownInputBlock().id,
AgentToggleInputBlock().id,
AgentTableInputBlock().id,
]

View File

@@ -2,7 +2,7 @@ from typing import Any
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.json import loads
from backend.util.json import json
class StepThroughItemsBlock(Block):
@@ -54,43 +54,20 @@ class StepThroughItemsBlock(Block):
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
# Security fix: Add limits to prevent DoS from large iterations
MAX_ITEMS = 10000 # Maximum items to iterate
MAX_ITEM_SIZE = 1024 * 1024 # 1MB per item
for data in [input_data.items, input_data.items_object, input_data.items_str]:
if not data:
continue
# Limit string size before parsing
if isinstance(data, str):
if len(data) > MAX_ITEM_SIZE:
raise ValueError(
f"Input too large: {len(data)} bytes > {MAX_ITEM_SIZE} bytes"
)
items = loads(data)
items = json.loads(data)
else:
items = data
# Check total item count
if isinstance(items, (list, dict)):
if len(items) > MAX_ITEMS:
raise ValueError(f"Too many items: {len(items)} > {MAX_ITEMS}")
iteration_count = 0
if isinstance(items, dict):
# If items is a dictionary, iterate over its values
for key, value in items.items():
if iteration_count >= MAX_ITEMS:
break
yield "item", value
yield "key", key # Fixed: should yield key, not item
iteration_count += 1
for item in items.values():
yield "item", item
yield "key", item
else:
# If items is a list, iterate over the list
for index, item in enumerate(items):
if iteration_count >= MAX_ITEMS:
break
yield "item", item
yield "key", index
iteration_count += 1

View File

@@ -1,8 +1,5 @@
from typing import List
from urllib.parse import quote
from typing_extensions import TypedDict
from backend.blocks.jina._auth import (
JinaCredentials,
JinaCredentialsField,
@@ -13,12 +10,6 @@ from backend.data.model import SchemaField
from backend.util.request import Requests
class Reference(TypedDict):
url: str
keyQuote: str
isSupportive: bool
class FactCheckerBlock(Block):
class Input(BlockSchema):
statement: str = SchemaField(
@@ -32,10 +23,6 @@ class FactCheckerBlock(Block):
)
result: bool = SchemaField(description="The result of the factuality check")
reason: str = SchemaField(description="The reason for the factuality result")
references: List[Reference] = SchemaField(
description="List of references supporting or contradicting the statement",
default=[],
)
error: str = SchemaField(description="Error message if the check fails")
def __init__(self):
@@ -66,11 +53,5 @@ class FactCheckerBlock(Block):
yield "factuality", data["factuality"]
yield "result", data["result"]
yield "reason", data["reason"]
# Yield references if present in the response
if "references" in data:
yield "references", data["references"]
else:
yield "references", []
else:
raise RuntimeError(f"Expected 'data' key not found in response: {data}")

View File

@@ -62,10 +62,10 @@ TEST_CREDENTIALS_OAUTH = OAuth2Credentials(
title="Mock Linear API key",
username="mock-linear-username",
access_token=SecretStr("mock-linear-access-token"),
access_token_expires_at=1672531200, # Mock expiration time for short-lived token
access_token_expires_at=None,
refresh_token=SecretStr("mock-linear-refresh-token"),
refresh_token_expires_at=None,
scopes=["read", "write"],
scopes=["mock-linear-scopes"],
)
TEST_CREDENTIALS_API_KEY = APIKeyCredentials(

View File

@@ -2,9 +2,7 @@
Linear OAuth handler implementation.
"""
import base64
import json
import time
from typing import Optional
from urllib.parse import urlencode
@@ -40,9 +38,8 @@ class LinearOAuthHandler(BaseOAuthHandler):
self.client_secret = client_secret
self.redirect_uri = redirect_uri
self.auth_base_url = "https://linear.app/oauth/authorize"
self.token_url = "https://api.linear.app/oauth/token"
self.token_url = "https://api.linear.app/oauth/token" # Correct token URL
self.revoke_url = "https://api.linear.app/oauth/revoke"
self.migrate_url = "https://api.linear.app/oauth/migrate_old_token"
def get_login_url(
self, scopes: list[str], state: str, code_challenge: Optional[str]
@@ -85,84 +82,19 @@ class LinearOAuthHandler(BaseOAuthHandler):
return True # Linear doesn't return JSON on successful revoke
async def migrate_old_token(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
"""
Migrate an old long-lived token to a new short-lived token with refresh token.
This uses Linear's /oauth/migrate_old_token endpoint to exchange current
long-lived tokens for short-lived tokens with refresh tokens without
requiring users to re-authorize.
"""
if not credentials.access_token:
raise ValueError("No access token to migrate")
request_body = {
"client_id": self.client_id,
"client_secret": self.client_secret,
}
headers = {
"Authorization": f"Bearer {credentials.access_token.get_secret_value()}",
"Content-Type": "application/x-www-form-urlencoded",
}
response = await Requests().post(
self.migrate_url, data=request_body, headers=headers
)
if not response.ok:
try:
error_data = response.json()
error_message = error_data.get("error", "Unknown error")
error_description = error_data.get("error_description", "")
if error_description:
error_message = f"{error_message}: {error_description}"
except json.JSONDecodeError:
error_message = response.text
raise LinearAPIException(
f"Failed to migrate Linear token ({response.status}): {error_message}",
response.status,
)
token_data = response.json()
# Extract token expiration
now = int(time.time())
expires_in = token_data.get("expires_in")
access_token_expires_at = None
if expires_in:
access_token_expires_at = now + expires_in
new_credentials = OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=credentials.title,
username=credentials.username,
access_token=token_data["access_token"],
scopes=credentials.scopes, # Preserve original scopes
refresh_token=token_data.get("refresh_token"),
access_token_expires_at=access_token_expires_at,
refresh_token_expires_at=None,
)
new_credentials.id = credentials.id
return new_credentials
async def _refresh_tokens(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
if not credentials.refresh_token:
raise ValueError(
"No refresh token available. Token may need to be migrated to the new refresh token system."
)
"No refresh token available."
) # Linear uses non-expiring tokens
return await self._request_tokens(
{
"refresh_token": credentials.refresh_token.get_secret_value(),
"grant_type": "refresh_token",
},
current_credentials=credentials,
}
)
async def _request_tokens(
@@ -170,33 +102,16 @@ class LinearOAuthHandler(BaseOAuthHandler):
params: dict[str, str],
current_credentials: Optional[OAuth2Credentials] = None,
) -> OAuth2Credentials:
# Determine if this is a refresh token request
is_refresh = params.get("grant_type") == "refresh_token"
# Build request body with appropriate grant_type
request_body = {
"client_id": self.client_id,
"client_secret": self.client_secret,
"grant_type": "authorization_code", # Ensure grant_type is correct
**params,
}
# Set default grant_type if not provided
if "grant_type" not in request_body:
request_body["grant_type"] = "authorization_code"
headers = {"Content-Type": "application/x-www-form-urlencoded"}
# For refresh token requests, support HTTP Basic Authentication as recommended
if is_refresh:
# Option 1: Use HTTP Basic Auth (preferred by Linear)
client_credentials = f"{self.client_id}:{self.client_secret}"
encoded_credentials = base64.b64encode(client_credentials.encode()).decode()
headers["Authorization"] = f"Basic {encoded_credentials}"
# Remove client credentials from body when using Basic Auth
request_body.pop("client_id", None)
request_body.pop("client_secret", None)
headers = {
"Content-Type": "application/x-www-form-urlencoded"
} # Correct header for token request
response = await Requests().post(
self.token_url, data=request_body, headers=headers
)
@@ -205,9 +120,6 @@ class LinearOAuthHandler(BaseOAuthHandler):
try:
error_data = response.json()
error_message = error_data.get("error", "Unknown error")
error_description = error_data.get("error_description", "")
if error_description:
error_message = f"{error_message}: {error_description}"
except json.JSONDecodeError:
error_message = response.text
raise LinearAPIException(
@@ -217,84 +129,27 @@ class LinearOAuthHandler(BaseOAuthHandler):
token_data = response.json()
# Extract token expiration if provided (for new refresh token implementation)
now = int(time.time())
expires_in = token_data.get("expires_in")
access_token_expires_at = None
if expires_in:
access_token_expires_at = now + expires_in
# Get username - preserve from current credentials if refreshing
username = None
if current_credentials and is_refresh:
username = current_credentials.username
elif "user" in token_data:
username = token_data["user"].get("name", "Unknown User")
else:
# Fetch username using the access token
username = await self._request_username(token_data["access_token"])
# Note: Linear access tokens do not expire, so we set expires_at to None
new_credentials = OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=current_credentials.title if current_credentials else None,
username=username or "Unknown User",
username=token_data.get("user", {}).get(
"name", "Unknown User"
), # extract name or set appropriate
access_token=token_data["access_token"],
scopes=(
token_data["scope"].split(",")
if "scope" in token_data
else (current_credentials.scopes if current_credentials else [])
),
refresh_token=token_data.get("refresh_token"),
access_token_expires_at=access_token_expires_at,
refresh_token_expires_at=None, # Linear doesn't provide refresh token expiration
scopes=token_data["scope"].split(
","
), # Linear returns comma-separated scopes
refresh_token=token_data.get(
"refresh_token"
), # Linear uses non-expiring tokens so this might be null
access_token_expires_at=None,
refresh_token_expires_at=None,
)
if current_credentials:
new_credentials.id = current_credentials.id
return new_credentials
async def get_access_token(self, credentials: OAuth2Credentials) -> str:
"""
Returns a valid access token, handling migration and refresh as needed.
This overrides the base implementation to handle Linear's token migration
from old long-lived tokens to new short-lived tokens with refresh tokens.
"""
# If token has no expiration and no refresh token, it might be an old token
# that needs migration
if (
credentials.access_token_expires_at is None
and credentials.refresh_token is None
):
try:
# Attempt to migrate the old token
migrated_credentials = await self.migrate_old_token(credentials)
# Update the credentials store would need to be handled by the caller
# For now, use the migrated credentials for this request
credentials = migrated_credentials
except LinearAPIException:
# Migration failed, try to use the old token as-is
# This maintains backward compatibility
pass
# Use the standard refresh logic from the base class
if self.needs_refresh(credentials):
credentials = await self.refresh_tokens(credentials)
return credentials.access_token.get_secret_value()
def needs_migration(self, credentials: OAuth2Credentials) -> bool:
"""
Check if credentials represent an old long-lived token that needs migration.
Old tokens have no expiration time and no refresh token.
"""
return (
credentials.access_token_expires_at is None
and credentials.refresh_token is None
)
async def _request_username(self, access_token: str) -> Optional[str]:
# Use the LinearClient to fetch user details using GraphQL
from ._api import LinearClient

View File

@@ -37,5 +37,5 @@ class Project(BaseModel):
name: str
description: str
priority: int
progress: float
content: str | None
progress: int
content: str

View File

@@ -1,9 +1,5 @@
# This file contains a lot of prompt block strings that would trigger "line too long"
# flake8: noqa: E501
import ast
import logging
import re
import secrets
from abc import ABC
from enum import Enum, EnumMeta
from json import JSONDecodeError
@@ -31,7 +27,7 @@ from backend.util.prompt import compress_prompt, estimate_token_count
from backend.util.text import TextFormatter
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
fmt = TextFormatter(autoescape=False)
fmt = TextFormatter()
LLMProviderName = Literal[
ProviderName.AIML_API,
@@ -101,9 +97,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
CLAUDE_4_OPUS = "claude-opus-4-20250514"
CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
CLAUDE_3_5_HAIKU = "claude-3-5-haiku-latest"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# AI/ML API models
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
@@ -208,20 +204,20 @@ MODEL_METADATA = {
"anthropic", 200000, 32000
), # claude-opus-4-1-20250805
LlmModel.CLAUDE_4_OPUS: ModelMetadata(
"anthropic", 200000, 32000
"anthropic", 200000, 8192
), # claude-4-opus-20250514
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
"anthropic", 200000, 64000
"anthropic", 200000, 8192
), # claude-4-sonnet-20250514
LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
"anthropic", 200000, 64000
), # claude-sonnet-4-5-20250929
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
"anthropic", 200000, 64000
), # claude-haiku-4-5-20251001
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
"anthropic", 200000, 64000
"anthropic", 200000, 8192
), # claude-3-7-sonnet-20250219
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata(
"anthropic", 200000, 8192
), # claude-3-5-sonnet-20241022
LlmModel.CLAUDE_3_5_HAIKU: ModelMetadata(
"anthropic", 200000, 8192
), # claude-3-5-haiku-20241022
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
"anthropic", 200000, 4096
), # claude-3-haiku-20240307
@@ -791,10 +787,9 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
default=False,
description=(
"Whether to force the LLM to produce a JSON-only response. "
"This can increase the block's reliability, "
"but may also reduce the quality of the response "
"because it prohibits the LLM from reasoning "
"before providing its JSON response."
"This can increase a model's reliability of outputting valid JSON. "
"However, it may also reduce the quality of the response, because it "
"prohibits the LLM from reasoning before providing its JSON response."
),
)
credentials: AICredentials = AICredentialsField()
@@ -865,18 +860,17 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
"llm_call": lambda *args, **kwargs: LLMResponse(
raw_response="",
prompt=[""],
response=(
'<json_output id="test123456">{\n'
' "key1": "key1Value",\n'
' "key2": "key2Value"\n'
"}</json_output>"
response=json.dumps(
{
"key1": "key1Value",
"key2": "key2Value",
}
),
tool_calls=None,
prompt_tokens=0,
completion_tokens=0,
reasoning=None,
),
"get_collision_proof_output_tag_id": lambda *args: "test123456",
)
},
)
@@ -913,6 +907,11 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
logger.debug(f"Calling LLM with input data: {input_data}")
prompt = [json.to_dict(p) for p in input_data.conversation_history]
def trim_prompt(s: str) -> str:
"""Removes indentation up to and including `|` from a multi-line prompt."""
lines = s.strip().split("\n")
return "\n".join([line.strip().lstrip("|") for line in lines])
values = input_data.prompt_values
if values:
input_data.prompt = fmt.format_string(input_data.prompt, values)
@@ -921,15 +920,24 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
if input_data.sys_prompt:
prompt.append({"role": "system", "content": input_data.sys_prompt})
# Use a one-time unique tag to prevent collisions with user/LLM content
output_tag_id = self.get_collision_proof_output_tag_id()
output_tag_start = f'<json_output id="{output_tag_id}">'
expected_json_type = "object" if not input_data.list_result else "array"
if input_data.expected_format:
sys_prompt = self.response_format_instructions(
input_data.expected_format,
list_mode=input_data.list_result,
pure_json_mode=input_data.force_json_output,
output_tag_start=output_tag_start,
expected_format = json.dumps(input_data.expected_format, indent=2)
if expected_json_type == "array":
indented_obj_format = expected_format.replace("\n", "\n ")
expected_format = f"[\n {indented_obj_format},\n ...\n]"
# Preserve indentation in prompt
expected_format = expected_format.replace("\n", "\n|")
sys_prompt = trim_prompt(
f"""
|You MUST respond with a valid JSON {expected_json_type} strictly following this format:
|{expected_format}
|
|If you cannot provide all the keys, you MUST provide an empty string for the values you cannot answer.
"""
)
prompt.append({"role": "system", "content": sys_prompt})
@@ -947,11 +955,11 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
except JSONDecodeError as e:
return f"JSON decode error: {e}"
logger.debug(f"LLM request: {prompt}")
error_feedback_message = ""
llm_model = input_data.model
for retry_count in range(input_data.retry):
logger.debug(f"LLM request: {prompt}")
try:
llm_response = await self.llm_call(
credentials=credentials,
@@ -976,53 +984,36 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
if input_data.expected_format:
try:
response_obj = self.get_json_from_response(
response_text,
pure_json_mode=input_data.force_json_output,
output_tag_start=output_tag_start,
json_strings = (
json.find_objects_in_text(response_text)
if expected_json_type == "object"
else json.find_arrays_in_text(response_text)
)
if not json_strings:
raise ValueError(
f"No JSON {expected_json_type} found in the response."
)
# Try to parse last JSON object/array in the response
response_obj = json.loads(json_strings.pop())
except (ValueError, JSONDecodeError) as parse_error:
censored_response = re.sub(r"[A-Za-z0-9]", "*", response_text)
response_snippet = (
f"{censored_response[:50]}...{censored_response[-30:]}"
)
logger.warning(
f"Error getting JSON from LLM response: {parse_error}\n\n"
f"Response start+end: `{response_snippet}`"
)
prompt.append({"role": "assistant", "content": response_text})
error_feedback_message = self.invalid_response_feedback(
parse_error,
was_parseable=False,
list_mode=input_data.list_result,
pure_json_mode=input_data.force_json_output,
output_tag_start=output_tag_start,
indented_parse_error = str(parse_error).replace("\n", "\n|")
error_feedback_message = trim_prompt(
f"""
|Your previous response did not contain a parseable JSON {expected_json_type}:
|
|{indented_parse_error}
|
|Please provide a valid JSON {expected_json_type} in your response that matches the expected format.
"""
)
prompt.append(
{"role": "user", "content": error_feedback_message}
)
continue
# Handle object response for `force_json_output`+`list_result`
if input_data.list_result and isinstance(response_obj, dict):
if "results" in response_obj and isinstance(
response_obj["results"], list
):
response_obj = response_obj["results"]
else:
error_feedback_message = (
"Expected an array of objects in the 'results' key, "
f"but got: {response_obj}"
)
prompt.append(
{"role": "assistant", "content": response_text}
)
prompt.append(
{"role": "user", "content": error_feedback_message}
)
continue
validation_errors = "\n".join(
[
validation_error
@@ -1047,12 +1038,12 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
return
prompt.append({"role": "assistant", "content": response_text})
error_feedback_message = self.invalid_response_feedback(
validation_errors,
was_parseable=True,
list_mode=input_data.list_result,
pure_json_mode=input_data.force_json_output,
output_tag_start=output_tag_start,
error_feedback_message = trim_prompt(
f"""
|Your response did not match the expected format:
|
|{validation_errors}
"""
)
prompt.append({"role": "user", "content": error_feedback_message})
else:
@@ -1084,127 +1075,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
raise RuntimeError(error_feedback_message)
def response_format_instructions(
self,
expected_object_format: dict[str, str],
*,
list_mode: bool,
pure_json_mode: bool,
output_tag_start: str,
) -> str:
expected_output_format = json.dumps(expected_object_format, indent=2)
output_type = "object" if not list_mode else "array"
outer_output_type = "object" if pure_json_mode else output_type
if output_type == "array":
indented_obj_format = expected_output_format.replace("\n", "\n ")
expected_output_format = f"[\n {indented_obj_format},\n ...\n]"
if pure_json_mode:
indented_list_format = expected_output_format.replace("\n", "\n ")
expected_output_format = (
"{\n"
' "reasoning": "... (optional)",\n' # for better performance
f' "results": {indented_list_format}\n'
"}"
)
# Preserve indentation in prompt
expected_output_format = expected_output_format.replace("\n", "\n|")
# Prepare prompt
if not pure_json_mode:
expected_output_format = (
f"{output_tag_start}\n{expected_output_format}\n</json_output>"
)
instructions = f"""
|In your response you MUST include a valid JSON {outer_output_type} strictly following this format:
|{expected_output_format}
|
|If you cannot provide all the keys, you MUST provide an empty string for the values you cannot answer.
""".strip()
if not pure_json_mode:
instructions += f"""
|
|You MUST enclose your final JSON answer in {output_tag_start}...</json_output> tags, even if the user specifies a different tag.
|There MUST be exactly ONE {output_tag_start}...</json_output> block in your response, which MUST ONLY contain the JSON {outer_output_type} and nothing else. Other text outside this block is allowed.
""".strip()
return trim_prompt(instructions)
def invalid_response_feedback(
self,
error,
*,
was_parseable: bool,
list_mode: bool,
pure_json_mode: bool,
output_tag_start: str,
) -> str:
outer_output_type = "object" if not list_mode or pure_json_mode else "array"
if was_parseable:
complaint = f"Your previous response did not match the expected {outer_output_type} format."
else:
complaint = f"Your previous response did not contain a parseable JSON {outer_output_type}."
indented_parse_error = str(error).replace("\n", "\n|")
instruction = (
f"Please provide a {output_tag_start}...</json_output> block containing a"
if not pure_json_mode
else "Please provide a"
) + f" valid JSON {outer_output_type} that matches the expected format."
return trim_prompt(
f"""
|{complaint}
|
|{indented_parse_error}
|
|{instruction}
"""
)
def get_json_from_response(
self, response_text: str, *, pure_json_mode: bool, output_tag_start: str
) -> dict[str, Any] | list[dict[str, Any]]:
if pure_json_mode:
# Handle pure JSON responses
try:
return json.loads(response_text)
except JSONDecodeError as first_parse_error:
# If that didn't work, try finding the { and } to deal with possible ```json fences etc.
json_start = response_text.find("{")
json_end = response_text.rfind("}")
try:
return json.loads(response_text[json_start : json_end + 1])
except JSONDecodeError:
# Raise the original error, as it's more likely to be relevant
raise first_parse_error from None
if output_tag_start not in response_text:
raise ValueError(
"Response does not contain the expected "
f"{output_tag_start}...</json_output> block."
)
json_output = (
response_text.split(output_tag_start, 1)[1]
.rsplit("</json_output>", 1)[0]
.strip()
)
return json.loads(json_output)
def get_collision_proof_output_tag_id(self) -> str:
return secrets.token_hex(8)
def trim_prompt(s: str) -> str:
"""Removes indentation up to and including `|` from a multi-line prompt."""
lines = s.strip().split("\n")
return "\n".join([line.strip().lstrip("|") for line in lines])
class AITextGeneratorBlock(AIBlockBase):
class Input(BlockSchema):
@@ -1400,27 +1270,11 @@ class AITextSummarizerBlock(AIBlockBase):
@staticmethod
def _split_text(text: str, max_tokens: int, overlap: int) -> list[str]:
# Security fix: Add validation to prevent DoS attacks
# Limit text size to prevent memory exhaustion
MAX_TEXT_LENGTH = 1_000_000 # 1MB character limit
MAX_CHUNKS = 100 # Maximum number of chunks to prevent excessive memory use
if len(text) > MAX_TEXT_LENGTH:
text = text[:MAX_TEXT_LENGTH]
# Ensure chunk_size is at least 1 to prevent infinite loops
chunk_size = max(1, max_tokens - overlap)
# Ensure overlap is less than max_tokens to prevent invalid configurations
if overlap >= max_tokens:
overlap = max(0, max_tokens - 1)
words = text.split()
chunks = []
chunk_size = max_tokens - overlap
for i in range(0, len(words), chunk_size):
if len(chunks) >= MAX_CHUNKS:
break # Limit the number of chunks to prevent memory exhaustion
chunk = " ".join(words[i : i + max_tokens])
chunks.append(chunk)
@@ -1554,9 +1408,7 @@ class AIConversationBlock(AIBlockBase):
("prompt", list),
],
test_mock={
"llm_call": lambda *args, **kwargs: dict(
response="The 2020 World Series was played at Globe Life Field in Arlington, Texas."
)
"llm_call": lambda *args, **kwargs: "The 2020 World Series was played at Globe Life Field in Arlington, Texas."
},
)
@@ -1585,7 +1437,7 @@ class AIConversationBlock(AIBlockBase):
),
credentials=credentials,
)
yield "response", response["response"]
yield "response", response
yield "prompt", self.prompt

View File

@@ -1,226 +0,0 @@
# flake8: noqa: E501
import logging
from enum import Enum
from typing import Any, Literal
import openai
from pydantic import SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.logging import TruncatedLogger
logger = TruncatedLogger(logging.getLogger(__name__), "[Perplexity-Block]")
class PerplexityModel(str, Enum):
"""Perplexity sonar models available via OpenRouter"""
SONAR = "perplexity/sonar"
SONAR_PRO = "perplexity/sonar-pro"
SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
PerplexityCredentials = CredentialsMetaInput[
Literal[ProviderName.OPEN_ROUTER], Literal["api_key"]
]
TEST_CREDENTIALS = APIKeyCredentials(
id="test-perplexity-creds",
provider="open_router",
api_key=SecretStr("mock-openrouter-api-key"),
title="Mock OpenRouter API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
def PerplexityCredentialsField() -> PerplexityCredentials:
return CredentialsField(
description="OpenRouter API key for accessing Perplexity models.",
)
class PerplexityBlock(Block):
class Input(BlockSchema):
prompt: str = SchemaField(
description="The query to send to the Perplexity model.",
placeholder="Enter your query here...",
)
model: PerplexityModel = SchemaField(
title="Perplexity Model",
default=PerplexityModel.SONAR,
description="The Perplexity sonar model to use.",
advanced=False,
)
credentials: PerplexityCredentials = PerplexityCredentialsField()
system_prompt: str = SchemaField(
title="System Prompt",
default="",
description="Optional system prompt to provide context to the model.",
advanced=True,
)
max_tokens: int | None = SchemaField(
advanced=True,
default=None,
description="The maximum number of tokens to generate.",
)
class Output(BlockSchema):
response: str = SchemaField(
description="The response from the Perplexity model."
)
annotations: list[dict[str, Any]] = SchemaField(
description="List of URL citations and annotations from the response."
)
error: str = SchemaField(description="Error message if the API call failed.")
def __init__(self):
super().__init__(
id="c8a5f2e9-8b3d-4a7e-9f6c-1d5e3c9b7a4f",
description="Query Perplexity's sonar models with real-time web search capabilities and receive annotated responses with source citations.",
categories={BlockCategory.AI, BlockCategory.SEARCH},
input_schema=PerplexityBlock.Input,
output_schema=PerplexityBlock.Output,
test_input={
"prompt": "What is the weather today?",
"model": PerplexityModel.SONAR,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("response", "The weather varies by location..."),
("annotations", list),
],
test_mock={
"call_perplexity": lambda *args, **kwargs: {
"response": "The weather varies by location...",
"annotations": [
{
"type": "url_citation",
"url_citation": {
"title": "weather.com",
"url": "https://weather.com",
},
}
],
}
},
)
self.execution_stats = NodeExecutionStats()
async def call_perplexity(
self,
credentials: APIKeyCredentials,
model: PerplexityModel,
prompt: str,
system_prompt: str = "",
max_tokens: int | None = None,
) -> dict[str, Any]:
"""Call Perplexity via OpenRouter and extract annotations."""
client = openai.AsyncOpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=credentials.api_key.get_secret_value(),
)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
try:
response = await client.chat.completions.create(
extra_headers={
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=model.value,
messages=messages,
max_tokens=max_tokens,
)
if not response.choices:
raise ValueError("No response from Perplexity via OpenRouter.")
# Extract the response content
response_content = response.choices[0].message.content or ""
# Extract annotations if present in the message
annotations = []
if hasattr(response.choices[0].message, "annotations"):
# If annotations are directly available
annotations = response.choices[0].message.annotations
else:
# Check if there's a raw response with annotations
raw = getattr(response.choices[0].message, "_raw_response", None)
if isinstance(raw, dict) and "annotations" in raw:
annotations = raw["annotations"]
if not annotations and hasattr(response, "model_extra"):
# Check model_extra for annotations
model_extra = response.model_extra
if isinstance(model_extra, dict):
# Check in choices
if "choices" in model_extra and len(model_extra["choices"]) > 0:
choice = model_extra["choices"][0]
if "message" in choice and "annotations" in choice["message"]:
annotations = choice["message"]["annotations"]
# Also check the raw response object for annotations
if not annotations:
raw = getattr(response, "_raw_response", None)
if isinstance(raw, dict):
# Check various possible locations for annotations
if "annotations" in raw:
annotations = raw["annotations"]
elif "choices" in raw and len(raw["choices"]) > 0:
choice = raw["choices"][0]
if "message" in choice and "annotations" in choice["message"]:
annotations = choice["message"]["annotations"]
# Update execution stats
if response.usage:
self.execution_stats.input_token_count = response.usage.prompt_tokens
self.execution_stats.output_token_count = (
response.usage.completion_tokens
)
return {"response": response_content, "annotations": annotations or []}
except Exception as e:
logger.error(f"Error calling Perplexity: {e}")
raise
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
logger.debug(f"Running Perplexity block with model: {input_data.model}")
try:
result = await self.call_perplexity(
credentials=credentials,
model=input_data.model,
prompt=input_data.prompt,
system_prompt=input_data.system_prompt,
max_tokens=input_data.max_tokens,
)
yield "response", result["response"]
yield "annotations", result["annotations"]
except Exception as e:
error_msg = f"Error calling Perplexity: {str(e)}"
logger.error(error_msg)
yield "error", error_msg

View File

@@ -1,5 +1,4 @@
import asyncio
import logging
from datetime import datetime, timedelta, timezone
from typing import Any
@@ -8,7 +7,6 @@ import pydantic
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import Requests
class RSSEntry(pydantic.BaseModel):
@@ -102,33 +100,8 @@ class ReadRSSFeedBlock(Block):
)
@staticmethod
async def parse_feed(url: str) -> dict[str, Any]:
# Security fix: Add protection against memory exhaustion attacks
MAX_FEED_SIZE = 10 * 1024 * 1024 # 10MB limit for RSS feeds
# Download feed content with size limit
try:
response = await Requests(raise_for_status=True).get(url)
# Check content length if available
content_length = response.headers.get("Content-Length")
if content_length and int(content_length) > MAX_FEED_SIZE:
raise ValueError(
f"Feed too large: {content_length} bytes exceeds {MAX_FEED_SIZE} limit"
)
# Get content with size limit
content = response.content
if len(content) > MAX_FEED_SIZE:
raise ValueError(f"Feed too large: exceeds {MAX_FEED_SIZE} byte limit")
# Parse with feedparser using the validated content
# feedparser has built-in protection against XML attacks
return feedparser.parse(content) # type: ignore
except Exception as e:
# Log error and return empty feed
logging.warning(f"Failed to parse RSS feed from {url}: {e}")
return {"entries": []}
def parse_feed(url: str) -> dict[str, Any]:
return feedparser.parse(url) # type: ignore
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
keep_going = True
@@ -138,7 +111,7 @@ class ReadRSSFeedBlock(Block):
while keep_going:
keep_going = input_data.run_continuously
feed = await self.parse_feed(input_data.rss_url)
feed = self.parse_feed(input_data.rss_url)
all_entries = []
for entry in feed["entries"]:

View File

@@ -13,11 +13,6 @@ from backend.data.block import (
BlockSchema,
BlockType,
)
from backend.data.dynamic_fields import (
extract_base_field_name,
get_dynamic_field_description,
is_dynamic_field,
)
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util import json
from backend.util.clients import get_database_manager_async_client
@@ -103,22 +98,6 @@ def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
return {"role": "tool", "tool_call_id": call_id, "content": content}
def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]:
"""
Safely convert raw_response to dictionary format for conversation history.
Handles different response types from different LLM providers.
"""
if isinstance(raw_response, str):
# Ollama returns a string, convert to dict format
return {"role": "assistant", "content": raw_response}
elif isinstance(raw_response, dict):
# Already a dict (from tests or some providers)
return raw_response
else:
# OpenAI/Anthropic return objects, convert with json.to_dict
return json.to_dict(raw_response)
def get_pending_tool_calls(conversation_history: list[Any]) -> dict[str, int]:
"""
All the tool calls entry in the conversation history requires a response.
@@ -282,7 +261,6 @@ class SmartDecisionMakerBlock(Block):
@staticmethod
def cleanup(s: str):
"""Clean up block names for use as tool function names."""
return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()
@staticmethod
@@ -310,66 +288,41 @@ class SmartDecisionMakerBlock(Block):
}
sink_block_input_schema = block.input_schema
properties = {}
field_mapping = {} # clean_name -> original_name
for link in links:
field_name = link.sink_name
is_dynamic = is_dynamic_field(field_name)
# Clean property key to ensure Anthropic API compatibility for ALL fields
clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
field_mapping[clean_field_name] = field_name
sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
if is_dynamic:
# For dynamic fields, use cleaned name but preserve original in description
properties[clean_field_name] = {
# Handle dynamic fields (e.g., values_#_*, items_$_*, etc.)
# These are fields that get merged by the executor into their base field
if (
"_#_" in link.sink_name
or "_$_" in link.sink_name
or "_@_" in link.sink_name
):
# For dynamic fields, provide a generic string schema
# The executor will handle merging these into the appropriate structure
properties[sink_name] = {
"type": "string",
"description": get_dynamic_field_description(field_name),
"description": f"Dynamic value for {link.sink_name}",
}
else:
# For regular fields, use the block's schema directly
# For regular fields, use the block's schema
try:
properties[clean_field_name] = (
sink_block_input_schema.get_field_schema(field_name)
properties[sink_name] = sink_block_input_schema.get_field_schema(
link.sink_name
)
except (KeyError, AttributeError):
# If field doesn't exist in schema, provide a generic one
properties[clean_field_name] = {
# If the field doesn't exist in the schema, provide a generic schema
properties[sink_name] = {
"type": "string",
"description": f"Value for {field_name}",
"description": f"Value for {link.sink_name}",
}
# Build the parameters schema using a single unified path
base_schema = block.input_schema.jsonschema()
base_required = set(base_schema.get("required", []))
# Compute required fields at the leaf level:
# - If a linked field is dynamic and its base is required in the block schema, require the leaf
# - If a linked field is regular and is required in the block schema, require the leaf
required_fields: set[str] = set()
for link in links:
field_name = link.sink_name
is_dynamic = is_dynamic_field(field_name)
# Always use cleaned field name for property key (Anthropic API compliance)
clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
if is_dynamic:
base_name = extract_base_field_name(field_name)
if base_name in base_required:
required_fields.add(clean_field_name)
else:
if field_name in base_required:
required_fields.add(clean_field_name)
tool_function["parameters"] = {
"type": "object",
**block.input_schema.jsonschema(),
"properties": properties,
"additionalProperties": False,
"required": sorted(required_fields),
}
# Store field mapping for later use in output processing
tool_function["_field_mapping"] = field_mapping
return {"type": "function", "function": tool_function}
@staticmethod
@@ -413,12 +366,13 @@ class SmartDecisionMakerBlock(Block):
sink_block_properties = sink_block_input_schema.get("properties", {}).get(
link.sink_name, {}
)
sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
description = (
sink_block_properties["description"]
if "description" in sink_block_properties
else f"The {link.sink_name} of the tool"
)
properties[link.sink_name] = {
properties[sink_name] = {
"type": "string",
"description": description,
"default": json.dumps(sink_block_properties.get("default", None)),
@@ -434,17 +388,24 @@ class SmartDecisionMakerBlock(Block):
return {"type": "function", "function": tool_function}
@staticmethod
async def _create_function_signature(
node_id: str,
) -> list[dict[str, Any]]:
async def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
"""
Creates function signatures for connected tools.
Creates function signatures for tools linked to a specified node within a graph.
This method filters the graph links to identify those that are tools and are
connected to the given node_id. It then constructs function signatures for each
tool based on the metadata and input schema of the linked nodes.
Args:
node_id: The node_id for which to create function signatures.
Returns:
List of function signatures for tools
list[dict[str, Any]]: A list of dictionaries, each representing a function signature
for a tool, including its name, description, and parameters.
Raises:
ValueError: If no tool links are found for the specified node_id, or if a sink node
or its metadata cannot be found.
"""
db_client = get_database_manager_async_client()
tools = [
@@ -469,116 +430,20 @@ class SmartDecisionMakerBlock(Block):
raise ValueError(f"Sink node not found: {links[0].sink_id}")
if sink_node.block_id == AgentExecutorBlock().id:
tool_func = (
return_tool_functions.append(
await SmartDecisionMakerBlock._create_agent_function_signature(
sink_node, links
)
)
return_tool_functions.append(tool_func)
else:
tool_func = (
return_tool_functions.append(
await SmartDecisionMakerBlock._create_block_function_signature(
sink_node, links
)
)
return_tool_functions.append(tool_func)
return return_tool_functions
async def _attempt_llm_call_with_validation(
self,
credentials: llm.APIKeyCredentials,
input_data: Input,
current_prompt: list[dict],
tool_functions: list[dict[str, Any]],
):
"""
Attempt a single LLM call with tool validation.
Returns the response if successful, raises ValueError if validation fails.
"""
resp = await llm.llm_call(
credentials=credentials,
llm_model=input_data.model,
prompt=current_prompt,
max_tokens=input_data.max_tokens,
tools=tool_functions,
ollama_host=input_data.ollama_host,
parallel_tool_calls=input_data.multiple_tool_calls,
)
# Track LLM usage stats per call
self.merge_stats(
NodeExecutionStats(
input_token_count=resp.prompt_tokens,
output_token_count=resp.completion_tokens,
llm_call_count=1,
)
)
if not resp.tool_calls:
return resp
validation_errors_list: list[str] = []
for tool_call in resp.tool_calls:
tool_name = tool_call.function.name
try:
tool_args = json.loads(tool_call.function.arguments)
except Exception as e:
validation_errors_list.append(
f"Tool call '{tool_name}' has invalid JSON arguments: {e}"
)
continue
# Find the tool definition to get the expected arguments
tool_def = next(
(
tool
for tool in tool_functions
if tool["function"]["name"] == tool_name
),
None,
)
if tool_def is None and len(tool_functions) == 1:
tool_def = tool_functions[0]
# Get parameters schema from tool definition
if (
tool_def
and "function" in tool_def
and "parameters" in tool_def["function"]
):
parameters = tool_def["function"]["parameters"]
expected_args = parameters.get("properties", {})
required_params = set(parameters.get("required", []))
else:
expected_args = {arg: {} for arg in tool_args.keys()}
required_params = set()
# Validate tool call arguments
provided_args = set(tool_args.keys())
expected_args_set = set(expected_args.keys())
# Check for unexpected arguments (typos)
unexpected_args = provided_args - expected_args_set
# Only check for missing REQUIRED parameters
missing_required_args = required_params - provided_args
if unexpected_args or missing_required_args:
error_msg = f"Tool call '{tool_name}' has parameter errors:"
if unexpected_args:
error_msg += f" Unknown parameters: {sorted(unexpected_args)}."
if missing_required_args:
error_msg += f" Missing required parameters: {sorted(missing_required_args)}."
error_msg += f" Expected parameters: {sorted(expected_args_set)}."
if required_params:
error_msg += f" Required parameters: {sorted(required_params)}."
validation_errors_list.append(error_msg)
if validation_errors_list:
raise ValueError("; ".join(validation_errors_list))
return resp
async def run(
self,
input_data: Input,
@@ -601,19 +466,27 @@ class SmartDecisionMakerBlock(Block):
if pending_tool_calls and input_data.last_tool_output is None:
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
# Only assign the last tool output to the first pending tool call
tool_output = []
if pending_tool_calls and input_data.last_tool_output is not None:
# Get the first pending tool call ID
first_call_id = next(iter(pending_tool_calls.keys()))
tool_output.append(
_create_tool_response(first_call_id, input_data.last_tool_output)
)
# Add tool output to prompt right away
prompt.extend(tool_output)
# Check if there are still pending tool calls after handling the first one
remaining_pending_calls = get_pending_tool_calls(prompt)
# If there are still pending tool calls, yield the conversation and return early
if remaining_pending_calls:
yield "conversations", prompt
return
# Fallback on adding tool output in the conversation history as user prompt.
elif input_data.last_tool_output:
logger.error(
f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
@@ -646,33 +519,24 @@ class SmartDecisionMakerBlock(Block):
):
prompt.append({"role": "user", "content": prefix + input_data.prompt})
current_prompt = list(prompt)
max_attempts = max(1, int(input_data.retry))
response = None
response = await llm.llm_call(
credentials=credentials,
llm_model=input_data.model,
prompt=prompt,
max_tokens=input_data.max_tokens,
tools=tool_functions,
ollama_host=input_data.ollama_host,
parallel_tool_calls=input_data.multiple_tool_calls,
)
last_error = None
for attempt in range(max_attempts):
try:
response = await self._attempt_llm_call_with_validation(
credentials, input_data, current_prompt, tool_functions
)
break
except ValueError as e:
last_error = e
error_feedback = (
"Your tool call had parameter errors. Please fix the following issues and try again:\n"
+ f"- {str(e)}\n"
+ "\nPlease make sure to use the exact parameter names as specified in the function schema."
)
current_prompt = list(current_prompt) + [
{"role": "user", "content": error_feedback}
]
if response is None:
raise last_error or ValueError(
"Failed to get valid response after all retry attempts"
# Track LLM usage stats
self.merge_stats(
NodeExecutionStats(
input_token_count=response.prompt_tokens,
output_token_count=response.completion_tokens,
llm_call_count=1,
)
)
if not response.tool_calls:
yield "finished", response.response
@@ -682,6 +546,7 @@ class SmartDecisionMakerBlock(Block):
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
# Find the tool definition to get the expected arguments
tool_def = next(
(
tool
@@ -690,6 +555,7 @@ class SmartDecisionMakerBlock(Block):
),
None,
)
if (
tool_def
and "function" in tool_def
@@ -697,38 +563,20 @@ class SmartDecisionMakerBlock(Block):
):
expected_args = tool_def["function"]["parameters"].get("properties", {})
else:
expected_args = {arg: {} for arg in tool_args.keys()}
expected_args = tool_args.keys()
# Get field mapping from tool definition
field_mapping = (
tool_def.get("function", {}).get("_field_mapping", {})
if tool_def
else {}
)
for clean_arg_name in expected_args:
# arg_name is now always the cleaned field name (for Anthropic API compliance)
# Get the original field name from field mapping for proper emit key generation
original_field_name = field_mapping.get(clean_arg_name, clean_arg_name)
arg_value = tool_args.get(clean_arg_name)
sanitized_tool_name = self.cleanup(tool_name)
sanitized_arg_name = self.cleanup(original_field_name)
emit_key = f"tools_^_{sanitized_tool_name}_~_{sanitized_arg_name}"
logger.debug(
"[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s",
graph_exec_id,
node_exec_id,
emit_key,
)
yield emit_key, arg_value
# Yield provided arguments and None for missing ones
for arg_name in expected_args:
if arg_name in tool_args:
yield f"tools_^_{tool_name}_~_{arg_name}", tool_args[arg_name]
else:
yield f"tools_^_{tool_name}_~_{arg_name}", None
# Add reasoning to conversation history if available
if response.reasoning:
prompt.append(
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
)
prompt.append(_convert_raw_response_to_dict(response.raw_response))
prompt.append(response.raw_response)
yield "conversations", prompt

View File

@@ -1,7 +1,6 @@
import logging
import signal
import threading
import warnings
from contextlib import contextmanager
from enum import Enum
@@ -27,13 +26,6 @@ from backend.sdk import (
SchemaField,
)
# Suppress false positive cleanup warning of litellm (a dependency of stagehand)
warnings.filterwarnings(
"ignore",
message="coroutine 'close_litellm_async_clients' was never awaited",
category=RuntimeWarning,
)
# Store the original method
original_register_signal_handlers = stagehand.main.Stagehand._register_signal_handlers

View File

@@ -19,7 +19,7 @@ async def test_block_ids_valid(block: Type[Block]):
# Skip list for blocks with known invalid UUIDs
skip_blocks = {
"GetWeatherInformationBlock",
"ExecuteCodeBlock",
"CodeExecutionBlock",
"CountdownTimerBlock",
"TwitterGetListTweetsBlock",
"TwitterRemoveListMemberBlock",

View File

@@ -1,269 +0,0 @@
"""
Test security fixes for various DoS vulnerabilities.
"""
import asyncio
from unittest.mock import patch
import pytest
from backend.blocks.code_extraction_block import CodeExtractionBlock
from backend.blocks.iteration import StepThroughItemsBlock
from backend.blocks.llm import AITextSummarizerBlock
from backend.blocks.text import ExtractTextInformationBlock
from backend.blocks.xml_parser import XMLParserBlock
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
class TestCodeExtractionBlockSecurity:
"""Test ReDoS fixes in CodeExtractionBlock."""
async def test_redos_protection(self):
"""Test that the regex patterns don't cause ReDoS."""
block = CodeExtractionBlock()
# Test with input that would previously cause ReDoS
malicious_input = "```python" + " " * 10000 # Large spaces
result = []
async for output_name, output_data in block.run(
CodeExtractionBlock.Input(text=malicious_input)
):
result.append((output_name, output_data))
# Should complete without hanging
assert len(result) >= 1
assert any(name == "remaining_text" for name, _ in result)
class TestAITextSummarizerBlockSecurity:
"""Test memory exhaustion fixes in AITextSummarizerBlock."""
def test_split_text_limits(self):
"""Test that _split_text has proper limits."""
# Test text size limit
large_text = "a" * 2_000_000 # 2MB text
result = AITextSummarizerBlock._split_text(large_text, 1000, 100)
# Should be truncated to 1MB
total_chars = sum(len(chunk) for chunk in result)
assert total_chars <= 1_000_000 + 1000 # Allow for chunk boundary
# Test chunk count limit
result = AITextSummarizerBlock._split_text("word " * 10000, 10, 9)
assert len(result) <= 100 # MAX_CHUNKS limit
# Test parameter validation
result = AITextSummarizerBlock._split_text(
"test", 10, 15
) # overlap > max_tokens
assert len(result) >= 1 # Should still work
class TestExtractTextInformationBlockSecurity:
"""Test ReDoS and memory exhaustion fixes in ExtractTextInformationBlock."""
async def test_text_size_limits(self):
"""Test text size limits."""
block = ExtractTextInformationBlock()
# Test with large input
large_text = "a" * 2_000_000 # 2MB
results = []
async for output_name, output_data in block.run(
ExtractTextInformationBlock.Input(
text=large_text, pattern=r"a+", find_all=True, group=0
)
):
results.append((output_name, output_data))
# Should complete and have limits applied
matched_results = [r for name, r in results if name == "matched_results"]
if matched_results:
assert len(matched_results[0]) <= 1000 # MAX_MATCHES limit
async def test_dangerous_pattern_timeout(self):
"""Test timeout protection for dangerous patterns."""
block = ExtractTextInformationBlock()
# Test with potentially dangerous lookahead pattern
test_input = "a" * 1000
# This should complete quickly due to timeout protection
start_time = asyncio.get_event_loop().time()
results = []
async for output_name, output_data in block.run(
ExtractTextInformationBlock.Input(
text=test_input, pattern=r"(?=.+)", find_all=True, group=0
)
):
results.append((output_name, output_data))
end_time = asyncio.get_event_loop().time()
# Should complete within reasonable time (much less than 5s timeout)
assert (end_time - start_time) < 10
async def test_redos_catastrophic_backtracking(self):
"""Test that ReDoS patterns with catastrophic backtracking are handled."""
block = ExtractTextInformationBlock()
# Pattern that causes catastrophic backtracking: (a+)+b
# With input "aaaaaaaaaaaaaaaaaaaaaaaaaaaa" (no 'b'), this causes exponential time
dangerous_pattern = r"(a+)+b"
test_input = "a" * 30 # 30 'a's without a 'b' at the end
# This should be handled by timeout protection or pattern detection
start_time = asyncio.get_event_loop().time()
results = []
async for output_name, output_data in block.run(
ExtractTextInformationBlock.Input(
text=test_input, pattern=dangerous_pattern, find_all=True, group=0
)
):
results.append((output_name, output_data))
end_time = asyncio.get_event_loop().time()
elapsed = end_time - start_time
# Should complete within timeout (6 seconds to be safe)
# The current threading.Timer approach doesn't work, so this will likely fail
# demonstrating the need for a fix
assert elapsed < 6, f"Regex took {elapsed}s, timeout mechanism failed"
# Should return empty results on timeout or no match
matched_results = [r for name, r in results if name == "matched_results"]
assert matched_results[0] == [] # No matches expected
class TestStepThroughItemsBlockSecurity:
"""Test iteration limits in StepThroughItemsBlock."""
async def test_item_count_limits(self):
"""Test maximum item count limits."""
block = StepThroughItemsBlock()
# Test with too many items
large_list = list(range(20000)) # Exceeds MAX_ITEMS (10000)
with pytest.raises(ValueError, match="Too many items"):
async for _ in block.run(StepThroughItemsBlock.Input(items=large_list)):
pass
async def test_string_size_limits(self):
"""Test string input size limits."""
block = StepThroughItemsBlock()
# Test with large JSON string
large_string = '["item"]' * 200000 # Large JSON string
with pytest.raises(ValueError, match="Input too large"):
async for _ in block.run(
StepThroughItemsBlock.Input(items_str=large_string)
):
pass
async def test_normal_iteration_works(self):
"""Test that normal iteration still works."""
block = StepThroughItemsBlock()
results = []
async for output_name, output_data in block.run(
StepThroughItemsBlock.Input(items=[1, 2, 3])
):
results.append((output_name, output_data))
# Should have 6 outputs (item, key for each of 3 items)
assert len(results) == 6
items = [data for name, data in results if name == "item"]
assert items == [1, 2, 3]
class TestXMLParserBlockSecurity:
"""Test XML size limits in XMLParserBlock."""
async def test_xml_size_limits(self):
"""Test XML input size limits."""
block = XMLParserBlock()
# Test with large XML - need to exceed 10MB limit
# Each "<item>data</item>" is 17 chars, need ~620K items for >10MB
large_xml = "<root>" + "<item>data</item>" * 620000 + "</root>"
with pytest.raises(ValueError, match="XML too large"):
async for _ in block.run(XMLParserBlock.Input(input_xml=large_xml)):
pass
class TestStoreMediaFileSecurity:
"""Test file storage security limits."""
@patch("backend.util.file.scan_content_safe")
@patch("backend.util.file.get_cloud_storage_handler")
async def test_file_size_limits(self, mock_cloud_storage, mock_scan):
"""Test file size limits."""
# Mock cloud storage handler - get_cloud_storage_handler is async
# but is_cloud_path and parse_cloud_path are sync methods
from unittest.mock import MagicMock
mock_handler = MagicMock()
mock_handler.is_cloud_path.return_value = False
# Make get_cloud_storage_handler an async function that returns the mock handler
async def async_get_handler():
return mock_handler
mock_cloud_storage.side_effect = async_get_handler
mock_scan.return_value = None
# Test with large base64 content
large_content = "a" * (200 * 1024 * 1024) # 200MB
large_data_uri = f"data:text/plain;base64,{large_content}"
with pytest.raises(ValueError, match="File too large"):
await store_media_file(
graph_exec_id="test",
file=MediaFileType(large_data_uri),
user_id="test_user",
)
@patch("backend.util.file.Path")
@patch("backend.util.file.scan_content_safe")
@patch("backend.util.file.get_cloud_storage_handler")
async def test_directory_size_limits(self, mock_cloud_storage, mock_scan, MockPath):
"""Test directory size limits."""
from unittest.mock import MagicMock
mock_handler = MagicMock()
mock_handler.is_cloud_path.return_value = False
async def async_get_handler():
return mock_handler
mock_cloud_storage.side_effect = async_get_handler
mock_scan.return_value = None
# Create mock path instance for the execution directory
mock_path_instance = MagicMock()
mock_path_instance.exists.return_value = True
# Mock glob to return files that total > 1GB
mock_file = MagicMock()
mock_file.is_file.return_value = True
mock_file.stat.return_value.st_size = 2 * 1024 * 1024 * 1024 # 2GB
mock_path_instance.glob.return_value = [mock_file]
# Make Path() return our mock
MockPath.return_value = mock_path_instance
# Should raise an error when directory size exceeds limit
with pytest.raises(ValueError, match="Disk usage limit exceeded"):
await store_media_file(
graph_exec_id="test",
file=MediaFileType(
"data:text/plain;base64,dGVzdA=="
), # Small test file
user_id="test_user",
)

View File

@@ -41,8 +41,6 @@ class TestLLMStatsTracking:
@pytest.mark.asyncio
async def test_ai_structured_response_block_tracks_stats(self):
"""Test that AIStructuredResponseGeneratorBlock correctly tracks stats."""
from unittest.mock import patch
import backend.blocks.llm as llm
block = llm.AIStructuredResponseGeneratorBlock()
@@ -52,7 +50,7 @@ class TestLLMStatsTracking:
return llm.LLMResponse(
raw_response="",
prompt=[],
response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
response='{"key1": "value1", "key2": "value2"}',
tool_calls=None,
prompt_tokens=15,
completion_tokens=25,
@@ -70,12 +68,10 @@ class TestLLMStatsTracking:
)
outputs = {}
# Mock secrets.token_hex to return consistent ID
with patch("secrets.token_hex", return_value="test123456"):
async for output_name, output_data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[output_name] = output_data
async for output_name, output_data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[output_name] = output_data
# Check stats
assert block.execution_stats.input_token_count == 15
@@ -146,7 +142,7 @@ class TestLLMStatsTracking:
return llm.LLMResponse(
raw_response="",
prompt=[],
response='<json_output id="test123456">{"wrong": "format"}</json_output>',
response='{"wrong": "format"}',
tool_calls=None,
prompt_tokens=10,
completion_tokens=15,
@@ -157,7 +153,7 @@ class TestLLMStatsTracking:
return llm.LLMResponse(
raw_response="",
prompt=[],
response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
response='{"key1": "value1", "key2": "value2"}',
tool_calls=None,
prompt_tokens=20,
completion_tokens=25,
@@ -176,12 +172,10 @@ class TestLLMStatsTracking:
)
outputs = {}
# Mock secrets.token_hex to return consistent ID
with patch("secrets.token_hex", return_value="test123456"):
async for output_name, output_data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[output_name] = output_data
async for output_name, output_data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[output_name] = output_data
# Check stats - should accumulate both calls
# For 2 attempts: attempt 1 (failed) + attempt 2 (success) = 2 total
@@ -274,8 +268,7 @@ class TestLLMStatsTracking:
mock_response.choices = [
MagicMock(
message=MagicMock(
content='<json_output id="test123456">{"summary": "Test chunk summary"}</json_output>',
tool_calls=None,
content='{"summary": "Test chunk summary"}', tool_calls=None
)
)
]
@@ -283,7 +276,7 @@ class TestLLMStatsTracking:
mock_response.choices = [
MagicMock(
message=MagicMock(
content='<json_output id="test123456">{"final_summary": "Test final summary"}</json_output>',
content='{"final_summary": "Test final summary"}',
tool_calls=None,
)
)
@@ -304,13 +297,11 @@ class TestLLMStatsTracking:
max_tokens=1000, # Large enough to avoid chunking
)
# Mock secrets.token_hex to return consistent ID
with patch("secrets.token_hex", return_value="test123456"):
outputs = {}
async for output_name, output_data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[output_name] = output_data
outputs = {}
async for output_name, output_data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[output_name] = output_data
print(f"Actual calls made: {call_count}")
print(f"Block stats: {block.execution_stats}")
@@ -362,7 +353,7 @@ class TestLLMStatsTracking:
assert block.execution_stats.llm_call_count == 1
# Check output
assert outputs["response"] == "AI response to conversation"
assert outputs["response"] == {"response": "AI response to conversation"}
@pytest.mark.asyncio
async def test_ai_list_generator_with_retries(self):
@@ -465,7 +456,7 @@ class TestLLMStatsTracking:
return llm.LLMResponse(
raw_response="",
prompt=[],
response='<json_output id="test123456">{"result": "test"}</json_output>',
response='{"result": "test"}',
tool_calls=None,
prompt_tokens=10,
completion_tokens=20,
@@ -484,12 +475,10 @@ class TestLLMStatsTracking:
# Run the block
outputs = {}
# Mock secrets.token_hex to return consistent ID
with patch("secrets.token_hex", return_value="test123456"):
async for output_name, output_data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[output_name] = output_data
async for output_name, output_data in block.run(
input_data, credentials=llm.TEST_CREDENTIALS
):
outputs[output_name] = output_data
# Block finished - now grab and assert stats
assert block.execution_stats is not None

View File

@@ -216,17 +216,8 @@ async def test_smart_decision_maker_tracks_llm_stats():
}
# Mock the _create_function_signature method to avoid database calls
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response,
), patch.object(
SmartDecisionMakerBlock,
"_create_function_signature",
new_callable=AsyncMock,
return_value=[],
with patch("backend.blocks.llm.llm_call", return_value=mock_response), patch.object(
SmartDecisionMakerBlock, "_create_function_signature", return_value=[]
):
# Create test input
@@ -258,471 +249,3 @@ async def test_smart_decision_maker_tracks_llm_stats():
# Verify outputs
assert "finished" in outputs # Should have finished since no tool calls
assert outputs["finished"] == "I need to think about this."
@pytest.mark.asyncio
async def test_smart_decision_maker_parameter_validation():
"""Test that SmartDecisionMakerBlock correctly validates tool call parameters."""
from unittest.mock import MagicMock, patch
import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
block = SmartDecisionMakerBlock()
# Mock tool functions with specific parameter schema
mock_tool_functions = [
{
"type": "function",
"function": {
"name": "search_keywords",
"description": "Search for keywords with difficulty filtering",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "Search query"},
"max_keyword_difficulty": {
"type": "integer",
"description": "Maximum keyword difficulty (required)",
},
"optional_param": {
"type": "string",
"description": "Optional parameter with default",
"default": "default_value",
},
},
"required": ["query", "max_keyword_difficulty"],
},
},
}
]
# Test case 1: Tool call with TYPO in parameter name (should retry and eventually fail)
mock_tool_call_with_typo = MagicMock()
mock_tool_call_with_typo.function.name = "search_keywords"
mock_tool_call_with_typo.function.arguments = '{"query": "test", "maximum_keyword_difficulty": 50}' # TYPO: maximum instead of max
mock_response_with_typo = MagicMock()
mock_response_with_typo.response = None
mock_response_with_typo.tool_calls = [mock_tool_call_with_typo]
mock_response_with_typo.prompt_tokens = 50
mock_response_with_typo.completion_tokens = 25
mock_response_with_typo.reasoning = None
mock_response_with_typo.raw_response = {"role": "assistant", "content": None}
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response_with_typo,
) as mock_llm_call, patch.object(
SmartDecisionMakerBlock,
"_create_function_signature",
new_callable=AsyncMock,
return_value=mock_tool_functions,
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2, # Set retry to 2 for testing
)
# Should raise ValueError after retries due to typo'd parameter name
with pytest.raises(ValueError) as exc_info:
outputs = {}
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph-id",
node_id="test-node-id",
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
):
outputs[output_name] = output_data
# Verify error message contains details about the typo
error_msg = str(exc_info.value)
assert "Tool call 'search_keywords' has parameter errors" in error_msg
assert "Unknown parameters: ['maximum_keyword_difficulty']" in error_msg
# Verify that LLM was called the expected number of times (retries)
assert mock_llm_call.call_count == 2 # Should retry based on input_data.retry
# Test case 2: Tool call missing REQUIRED parameter (should raise ValueError)
mock_tool_call_missing_required = MagicMock()
mock_tool_call_missing_required.function.name = "search_keywords"
mock_tool_call_missing_required.function.arguments = (
'{"query": "test"}' # Missing required max_keyword_difficulty
)
mock_response_missing_required = MagicMock()
mock_response_missing_required.response = None
mock_response_missing_required.tool_calls = [mock_tool_call_missing_required]
mock_response_missing_required.prompt_tokens = 50
mock_response_missing_required.completion_tokens = 25
mock_response_missing_required.reasoning = None
mock_response_missing_required.raw_response = {"role": "assistant", "content": None}
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response_missing_required,
), patch.object(
SmartDecisionMakerBlock,
"_create_function_signature",
new_callable=AsyncMock,
return_value=mock_tool_functions,
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
)
# Should raise ValueError due to missing required parameter
with pytest.raises(ValueError) as exc_info:
outputs = {}
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph-id",
node_id="test-node-id",
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
):
outputs[output_name] = output_data
error_msg = str(exc_info.value)
assert "Tool call 'search_keywords' has parameter errors" in error_msg
assert "Missing required parameters: ['max_keyword_difficulty']" in error_msg
# Test case 3: Valid tool call with OPTIONAL parameter missing (should succeed)
mock_tool_call_valid = MagicMock()
mock_tool_call_valid.function.name = "search_keywords"
mock_tool_call_valid.function.arguments = '{"query": "test", "max_keyword_difficulty": 50}' # optional_param missing, but that's OK
mock_response_valid = MagicMock()
mock_response_valid.response = None
mock_response_valid.tool_calls = [mock_tool_call_valid]
mock_response_valid.prompt_tokens = 50
mock_response_valid.completion_tokens = 25
mock_response_valid.reasoning = None
mock_response_valid.raw_response = {"role": "assistant", "content": None}
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response_valid,
), patch.object(
SmartDecisionMakerBlock,
"_create_function_signature",
new_callable=AsyncMock,
return_value=mock_tool_functions,
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
)
# Should succeed - optional parameter missing is OK
outputs = {}
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph-id",
node_id="test-node-id",
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
):
outputs[output_name] = output_data
# Verify tool outputs were generated correctly
assert "tools_^_search_keywords_~_query" in outputs
assert outputs["tools_^_search_keywords_~_query"] == "test"
assert "tools_^_search_keywords_~_max_keyword_difficulty" in outputs
assert outputs["tools_^_search_keywords_~_max_keyword_difficulty"] == 50
# Optional parameter should be None when not provided
assert "tools_^_search_keywords_~_optional_param" in outputs
assert outputs["tools_^_search_keywords_~_optional_param"] is None
# Test case 4: Valid tool call with ALL parameters (should succeed)
mock_tool_call_all_params = MagicMock()
mock_tool_call_all_params.function.name = "search_keywords"
mock_tool_call_all_params.function.arguments = '{"query": "test", "max_keyword_difficulty": 50, "optional_param": "custom_value"}'
mock_response_all_params = MagicMock()
mock_response_all_params.response = None
mock_response_all_params.tool_calls = [mock_tool_call_all_params]
mock_response_all_params.prompt_tokens = 50
mock_response_all_params.completion_tokens = 25
mock_response_all_params.reasoning = None
mock_response_all_params.raw_response = {"role": "assistant", "content": None}
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response_all_params,
), patch.object(
SmartDecisionMakerBlock,
"_create_function_signature",
new_callable=AsyncMock,
return_value=mock_tool_functions,
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
)
# Should succeed with all parameters
outputs = {}
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph-id",
node_id="test-node-id",
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
):
outputs[output_name] = output_data
# Verify all tool outputs were generated correctly
assert outputs["tools_^_search_keywords_~_query"] == "test"
assert outputs["tools_^_search_keywords_~_max_keyword_difficulty"] == 50
assert outputs["tools_^_search_keywords_~_optional_param"] == "custom_value"
@pytest.mark.asyncio
async def test_smart_decision_maker_raw_response_conversion():
"""Test that SmartDecisionMaker correctly handles different raw_response types with retry mechanism."""
from unittest.mock import MagicMock, patch
import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
block = SmartDecisionMakerBlock()
# Mock tool functions
mock_tool_functions = [
{
"type": "function",
"function": {
"name": "test_tool",
"parameters": {
"type": "object",
"properties": {"param": {"type": "string"}},
"required": ["param"],
},
},
}
]
# Test case 1: Simulate ChatCompletionMessage raw_response that caused the original error
class MockChatCompletionMessage:
"""Simulate OpenAI's ChatCompletionMessage object that lacks .get() method"""
def __init__(self, role, content, tool_calls=None):
self.role = role
self.content = content
self.tool_calls = tool_calls or []
# This is what caused the error - no .get() method
# def get(self, key, default=None): # Intentionally missing
# First response: has invalid parameter name (triggers retry)
mock_tool_call_invalid = MagicMock()
mock_tool_call_invalid.function.name = "test_tool"
mock_tool_call_invalid.function.arguments = (
'{"wrong_param": "test_value"}' # Invalid parameter name
)
mock_response_retry = MagicMock()
mock_response_retry.response = None
mock_response_retry.tool_calls = [mock_tool_call_invalid]
mock_response_retry.prompt_tokens = 50
mock_response_retry.completion_tokens = 25
mock_response_retry.reasoning = None
# This would cause the original error without our fix
mock_response_retry.raw_response = MockChatCompletionMessage(
role="assistant", content=None, tool_calls=[mock_tool_call_invalid]
)
# Second response: successful (correct parameter name)
mock_tool_call_valid = MagicMock()
mock_tool_call_valid.function.name = "test_tool"
mock_tool_call_valid.function.arguments = (
'{"param": "test_value"}' # Correct parameter name
)
mock_response_success = MagicMock()
mock_response_success.response = None
mock_response_success.tool_calls = [mock_tool_call_valid]
mock_response_success.prompt_tokens = 50
mock_response_success.completion_tokens = 25
mock_response_success.reasoning = None
mock_response_success.raw_response = MockChatCompletionMessage(
role="assistant", content=None, tool_calls=[mock_tool_call_valid]
)
# Mock llm_call to return different responses on different calls
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call", new_callable=AsyncMock
) as mock_llm_call, patch.object(
SmartDecisionMakerBlock,
"_create_function_signature",
new_callable=AsyncMock,
return_value=mock_tool_functions,
):
# First call returns response that will trigger retry due to validation error
# Second call returns successful response
mock_llm_call.side_effect = [mock_response_retry, mock_response_success]
input_data = SmartDecisionMakerBlock.Input(
prompt="Test prompt",
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2,
)
# Should succeed after retry, demonstrating our helper function works
outputs = {}
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph-id",
node_id="test-node-id",
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
):
outputs[output_name] = output_data
# Verify the tool output was generated successfully
assert "tools_^_test_tool_~_param" in outputs
assert outputs["tools_^_test_tool_~_param"] == "test_value"
# Verify conversation history was properly maintained
assert "conversations" in outputs
conversations = outputs["conversations"]
assert len(conversations) > 0
# The conversations should contain properly converted raw_response objects as dicts
# This would have failed with the original bug due to ChatCompletionMessage.get() error
for msg in conversations:
assert isinstance(msg, dict), f"Expected dict, got {type(msg)}"
if msg.get("role") == "assistant":
# Should have been converted from ChatCompletionMessage to dict
assert "role" in msg
# Verify LLM was called twice (initial + 1 retry)
assert mock_llm_call.call_count == 2
# Test case 2: Test with different raw_response types (Ollama string, dict)
# Test Ollama string response
mock_response_ollama = MagicMock()
mock_response_ollama.response = "I'll help you with that."
mock_response_ollama.tool_calls = None
mock_response_ollama.prompt_tokens = 30
mock_response_ollama.completion_tokens = 15
mock_response_ollama.reasoning = None
mock_response_ollama.raw_response = (
"I'll help you with that." # Ollama returns string
)
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response_ollama,
), patch.object(
SmartDecisionMakerBlock,
"_create_function_signature",
new_callable=AsyncMock,
return_value=[], # No tools for this test
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Simple prompt",
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
)
outputs = {}
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph-id",
node_id="test-node-id",
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
):
outputs[output_name] = output_data
# Should finish since no tool calls
assert "finished" in outputs
assert outputs["finished"] == "I'll help you with that."
# Test case 3: Test with dict raw_response (some providers/tests)
mock_response_dict = MagicMock()
mock_response_dict.response = "Test response"
mock_response_dict.tool_calls = None
mock_response_dict.prompt_tokens = 25
mock_response_dict.completion_tokens = 10
mock_response_dict.reasoning = None
mock_response_dict.raw_response = {
"role": "assistant",
"content": "Test response",
} # Dict format
from unittest.mock import AsyncMock
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response_dict,
), patch.object(
SmartDecisionMakerBlock,
"_create_function_signature",
new_callable=AsyncMock,
return_value=[],
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Another test",
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
)
outputs = {}
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph-id",
node_id="test-node-id",
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
):
outputs[output_name] = output_data
assert "finished" in outputs
assert outputs["finished"] == "Test response"

View File

@@ -48,24 +48,16 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
assert "parameters" in signature["function"]
assert "properties" in signature["function"]["parameters"]
# Check that dynamic fields are handled with original names
# Check that dynamic fields are handled
properties = signature["function"]["parameters"]["properties"]
assert len(properties) == 3 # Should have all three fields
# Check that field names are cleaned (for Anthropic API compatibility)
assert "values___name" in properties
assert "values___age" in properties
assert "values___city" in properties
# Each dynamic field should have proper schema with descriptive text
for field_name, prop_value in properties.items():
# Each dynamic field should have proper schema
for prop_value in properties.values():
assert "type" in prop_value
assert prop_value["type"] == "string" # Dynamic fields get string type
assert "description" in prop_value
# Check that descriptions properly explain the dynamic field
if field_name == "values___name":
assert "Dictionary field 'name'" in prop_value["description"]
assert "values['name']" in prop_value["description"]
assert "Dynamic value for" in prop_value["description"]
@pytest.mark.asyncio
@@ -104,18 +96,10 @@ async def test_smart_decision_maker_handles_dynamic_list_fields():
properties = signature["function"]["parameters"]["properties"]
assert len(properties) == 2 # Should have both list items
# Check that field names are cleaned (for Anthropic API compatibility)
assert "entries___0" in properties
assert "entries___1" in properties
# Each dynamic field should have proper schema with descriptive text
for field_name, prop_value in properties.items():
# Each dynamic field should have proper schema
for prop_value in properties.values():
assert prop_value["type"] == "string"
assert "description" in prop_value
# Check that descriptions properly explain the list field
if field_name == "entries___0":
assert "List item 0" in prop_value["description"]
assert "entries[0]" in prop_value["description"]
assert "Dynamic value for" in prop_value["description"]
@pytest.mark.asyncio

View File

@@ -1,553 +0,0 @@
"""Comprehensive tests for SmartDecisionMakerBlock dynamic field handling."""
import json
from unittest.mock import AsyncMock, Mock, patch
import pytest
from backend.blocks.data_manipulation import AddToListBlock, CreateDictionaryBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.text import MatchTextPatternBlock
from backend.data.dynamic_fields import get_dynamic_field_description
@pytest.mark.asyncio
async def test_dynamic_field_description_generation():
"""Test that dynamic field descriptions are generated correctly."""
# Test dictionary field description
desc = get_dynamic_field_description("values_#_name")
assert "Dictionary field 'name' for base field 'values'" in desc
assert "values['name']" in desc
# Test list field description
desc = get_dynamic_field_description("items_$_0")
assert "List item 0 for base field 'items'" in desc
assert "items[0]" in desc
# Test object field description
desc = get_dynamic_field_description("user_@_email")
assert "Object attribute 'email' for base field 'user'" in desc
assert "user.email" in desc
# Test regular field fallback
desc = get_dynamic_field_description("regular_field")
assert desc == "Value for regular_field"
@pytest.mark.asyncio
async def test_create_block_function_signature_with_dict_fields():
"""Test that function signatures are created correctly for dictionary dynamic fields."""
block = SmartDecisionMakerBlock()
# Create a mock node for CreateDictionaryBlock
mock_node = Mock()
mock_node.block = CreateDictionaryBlock()
mock_node.block_id = CreateDictionaryBlock().id
mock_node.input_default = {}
# Create mock links with dynamic dictionary fields (source sanitized, sink original)
mock_links = [
Mock(
source_name="tools_^_create_dict_~_values___name", # Sanitized source
sink_name="values_#_name", # Original sink
sink_id="dict_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_create_dict_~_values___age", # Sanitized source
sink_name="values_#_age", # Original sink
sink_id="dict_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_create_dict_~_values___email", # Sanitized source
sink_name="values_#_email", # Original sink
sink_id="dict_node_id",
source_id="smart_decision_node_id",
),
]
# Generate function signature
signature = await block._create_block_function_signature(mock_node, mock_links) # type: ignore
# Verify the signature structure
assert signature["type"] == "function"
assert "function" in signature
assert "parameters" in signature["function"]
assert "properties" in signature["function"]["parameters"]
# Check that dynamic fields are handled with original names
properties = signature["function"]["parameters"]["properties"]
assert len(properties) == 3
# Check cleaned field names (for Anthropic API compatibility)
assert "values___name" in properties
assert "values___age" in properties
assert "values___email" in properties
# Check descriptions mention they are dictionary fields
assert "Dictionary field" in properties["values___name"]["description"]
assert "values['name']" in properties["values___name"]["description"]
assert "Dictionary field" in properties["values___age"]["description"]
assert "values['age']" in properties["values___age"]["description"]
assert "Dictionary field" in properties["values___email"]["description"]
assert "values['email']" in properties["values___email"]["description"]
@pytest.mark.asyncio
async def test_create_block_function_signature_with_list_fields():
"""Test that function signatures are created correctly for list dynamic fields."""
block = SmartDecisionMakerBlock()
# Create a mock node for AddToListBlock
mock_node = Mock()
mock_node.block = AddToListBlock()
mock_node.block_id = AddToListBlock().id
mock_node.input_default = {}
# Create mock links with dynamic list fields
mock_links = [
Mock(
source_name="tools_^_add_list_~_0",
sink_name="entries_$_0", # Dynamic list field
sink_id="list_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_add_list_~_1",
sink_name="entries_$_1", # Dynamic list field
sink_id="list_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_add_list_~_2",
sink_name="entries_$_2", # Dynamic list field
sink_id="list_node_id",
source_id="smart_decision_node_id",
),
]
# Generate function signature
signature = await block._create_block_function_signature(mock_node, mock_links) # type: ignore
# Verify the signature structure
assert signature["type"] == "function"
properties = signature["function"]["parameters"]["properties"]
# Check cleaned field names (for Anthropic API compatibility)
assert "entries___0" in properties
assert "entries___1" in properties
assert "entries___2" in properties
# Check descriptions mention they are list items
assert "List item 0" in properties["entries___0"]["description"]
assert "entries[0]" in properties["entries___0"]["description"]
assert "List item 1" in properties["entries___1"]["description"]
assert "entries[1]" in properties["entries___1"]["description"]
@pytest.mark.asyncio
async def test_create_block_function_signature_with_object_fields():
"""Test that function signatures are created correctly for object dynamic fields."""
block = SmartDecisionMakerBlock()
# Create a mock node for MatchTextPatternBlock (simulating object fields)
mock_node = Mock()
mock_node.block = MatchTextPatternBlock()
mock_node.block_id = MatchTextPatternBlock().id
mock_node.input_default = {}
# Create mock links with dynamic object fields
mock_links = [
Mock(
source_name="tools_^_extract_~_user_name",
sink_name="data_@_user_name", # Dynamic object field
sink_id="extract_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_extract_~_user_email",
sink_name="data_@_user_email", # Dynamic object field
sink_id="extract_node_id",
source_id="smart_decision_node_id",
),
]
# Generate function signature
signature = await block._create_block_function_signature(mock_node, mock_links) # type: ignore
# Verify the signature structure
properties = signature["function"]["parameters"]["properties"]
# Check cleaned field names (for Anthropic API compatibility)
assert "data___user_name" in properties
assert "data___user_email" in properties
# Check descriptions mention they are object attributes
assert "Object attribute" in properties["data___user_name"]["description"]
assert "data.user_name" in properties["data___user_name"]["description"]
@pytest.mark.asyncio
async def test_create_function_signature():
"""Test that the mapping between sanitized and original field names is built correctly."""
block = SmartDecisionMakerBlock()
# Mock the database client and connected nodes
with patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
) as mock_db:
mock_client = AsyncMock()
mock_db.return_value = mock_client
# Create mock nodes and links
mock_dict_node = Mock()
mock_dict_node.block = CreateDictionaryBlock()
mock_dict_node.block_id = CreateDictionaryBlock().id
mock_dict_node.input_default = {}
mock_list_node = Mock()
mock_list_node.block = AddToListBlock()
mock_list_node.block_id = AddToListBlock().id
mock_list_node.input_default = {}
# Mock links with dynamic fields
dict_link1 = Mock(
source_name="tools_^_create_dictionary_~_name",
sink_name="values_#_name",
sink_id="dict_node_id",
source_id="test_node_id",
)
dict_link2 = Mock(
source_name="tools_^_create_dictionary_~_age",
sink_name="values_#_age",
sink_id="dict_node_id",
source_id="test_node_id",
)
list_link = Mock(
source_name="tools_^_add_to_list_~_0",
sink_name="entries_$_0",
sink_id="list_node_id",
source_id="test_node_id",
)
mock_client.get_connected_output_nodes.return_value = [
(dict_link1, mock_dict_node),
(dict_link2, mock_dict_node),
(list_link, mock_list_node),
]
# Call the method that builds signatures
tool_functions = await block._create_function_signature("test_node_id")
# Verify we got 2 tool functions (one for dict, one for list)
assert len(tool_functions) == 2
# Verify the tool functions contain the dynamic field names
dict_tool = next(
(
tool
for tool in tool_functions
if tool["function"]["name"] == "createdictionaryblock"
),
None,
)
assert dict_tool is not None
dict_properties = dict_tool["function"]["parameters"]["properties"]
assert "values___name" in dict_properties
assert "values___age" in dict_properties
list_tool = next(
(
tool
for tool in tool_functions
if tool["function"]["name"] == "addtolistblock"
),
None,
)
assert list_tool is not None
list_properties = list_tool["function"]["parameters"]["properties"]
assert "entries___0" in list_properties
@pytest.mark.asyncio
async def test_output_yielding_with_dynamic_fields():
"""Test that outputs are yielded correctly with dynamic field names mapped back."""
block = SmartDecisionMakerBlock()
# No more sanitized mapping needed since we removed sanitization
# Mock LLM response with tool calls
mock_response = Mock()
mock_response.tool_calls = [
Mock(
function=Mock(
arguments=json.dumps(
{
"values___name": "Alice",
"values___age": 30,
"values___email": "alice@example.com",
}
),
)
)
]
# Ensure function name is a real string, not a Mock name
mock_response.tool_calls[0].function.name = "createdictionaryblock"
mock_response.reasoning = "Creating a dictionary with user information"
mock_response.raw_response = {"role": "assistant", "content": "test"}
mock_response.prompt_tokens = 100
mock_response.completion_tokens = 50
# Mock the LLM call
with patch(
"backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
) as mock_llm:
mock_llm.return_value = mock_response
# Mock the function signature creation
with patch.object(
block, "_create_function_signature", new_callable=AsyncMock
) as mock_sig:
mock_sig.return_value = [
{
"type": "function",
"function": {
"name": "createdictionaryblock",
"parameters": {
"type": "object",
"properties": {
"values___name": {"type": "string"},
"values___age": {"type": "number"},
"values___email": {"type": "string"},
},
},
},
}
]
# Create input data
from backend.blocks import llm
input_data = block.input_schema(
prompt="Create a user dictionary",
credentials=llm.TEST_CREDENTIALS_INPUT,
model=llm.LlmModel.GPT4O,
)
# Run the block
outputs = {}
async for output_name, output_value in block.run(
input_data,
credentials=llm.TEST_CREDENTIALS,
graph_id="test_graph",
node_id="test_node",
graph_exec_id="test_exec",
node_exec_id="test_node_exec",
user_id="test_user",
):
outputs[output_name] = output_value
# Verify the outputs use sanitized field names (matching frontend normalizeToolName)
assert "tools_^_createdictionaryblock_~_values___name" in outputs
assert outputs["tools_^_createdictionaryblock_~_values___name"] == "Alice"
assert "tools_^_createdictionaryblock_~_values___age" in outputs
assert outputs["tools_^_createdictionaryblock_~_values___age"] == 30
assert "tools_^_createdictionaryblock_~_values___email" in outputs
assert (
outputs["tools_^_createdictionaryblock_~_values___email"]
== "alice@example.com"
)
@pytest.mark.asyncio
async def test_mixed_regular_and_dynamic_fields():
"""Test handling of blocks with both regular and dynamic fields."""
block = SmartDecisionMakerBlock()
# Create a mock node
mock_node = Mock()
mock_node.block = Mock()
mock_node.block.name = "TestBlock"
mock_node.block.description = "A test block"
mock_node.block.input_schema = Mock()
# Mock the get_field_schema to return a proper schema for regular fields
def get_field_schema(field_name):
if field_name == "regular_field":
return {"type": "string", "description": "A regular field"}
elif field_name == "values":
return {"type": "object", "description": "A dictionary field"}
else:
raise KeyError(f"Field {field_name} not found")
mock_node.block.input_schema.get_field_schema = get_field_schema
mock_node.block.input_schema.jsonschema = Mock(
return_value={"properties": {}, "required": []}
)
# Create links with both regular and dynamic fields
mock_links = [
Mock(
source_name="tools_^_test_~_regular",
sink_name="regular_field", # Regular field
sink_id="test_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_test_~_dict_key",
sink_name="values_#_key1", # Dynamic dict field
sink_id="test_node_id",
source_id="smart_decision_node_id",
),
Mock(
source_name="tools_^_test_~_dict_key2",
sink_name="values_#_key2", # Dynamic dict field
sink_id="test_node_id",
source_id="smart_decision_node_id",
),
]
# Generate function signature
signature = await block._create_block_function_signature(mock_node, mock_links) # type: ignore
# Check properties
properties = signature["function"]["parameters"]["properties"]
assert len(properties) == 3
# Regular field should have its original schema
assert "regular_field" in properties
assert properties["regular_field"]["description"] == "A regular field"
# Dynamic fields should have generated descriptions
assert "values___key1" in properties
assert "Dictionary field" in properties["values___key1"]["description"]
assert "values___key2" in properties
assert "Dictionary field" in properties["values___key2"]["description"]
@pytest.mark.asyncio
async def test_validation_errors_dont_pollute_conversation():
"""Test that validation errors are only used during retries and don't pollute the conversation."""
block = SmartDecisionMakerBlock()
# Track conversation history changes
conversation_snapshots = []
# Mock response with invalid tool call (missing required parameter)
invalid_response = Mock()
invalid_response.tool_calls = [
Mock(
function=Mock(
arguments=json.dumps({"wrong_param": "value"}), # Wrong parameter name
)
)
]
# Ensure function name is a real string, not a Mock name
invalid_response.tool_calls[0].function.name = "test_tool"
invalid_response.reasoning = None
invalid_response.raw_response = {"role": "assistant", "content": "invalid"}
invalid_response.prompt_tokens = 100
invalid_response.completion_tokens = 50
# Mock valid response after retry
valid_response = Mock()
valid_response.tool_calls = [
Mock(function=Mock(arguments=json.dumps({"correct_param": "value"})))
]
# Ensure function name is a real string, not a Mock name
valid_response.tool_calls[0].function.name = "test_tool"
valid_response.reasoning = None
valid_response.raw_response = {"role": "assistant", "content": "valid"}
valid_response.prompt_tokens = 100
valid_response.completion_tokens = 50
call_count = 0
async def mock_llm_call(**kwargs):
nonlocal call_count
# Capture conversation state
conversation_snapshots.append(kwargs.get("prompt", []).copy())
call_count += 1
if call_count == 1:
return invalid_response
else:
return valid_response
# Mock the LLM call
with patch(
"backend.blocks.smart_decision_maker.llm.llm_call", new_callable=AsyncMock
) as mock_llm:
mock_llm.side_effect = mock_llm_call
# Mock the function signature creation
with patch.object(
block, "_create_function_signature", new_callable=AsyncMock
) as mock_sig:
mock_sig.return_value = [
{
"type": "function",
"function": {
"name": "test_tool",
"parameters": {
"type": "object",
"properties": {
"correct_param": {
"type": "string",
"description": "The correct parameter",
}
},
"required": ["correct_param"],
},
},
}
]
# Create input data
from backend.blocks import llm
input_data = block.input_schema(
prompt="Test prompt",
credentials=llm.TEST_CREDENTIALS_INPUT,
model=llm.LlmModel.GPT4O,
retry=3, # Allow retries
)
# Run the block
outputs = {}
async for output_name, output_value in block.run(
input_data,
credentials=llm.TEST_CREDENTIALS,
graph_id="test_graph",
node_id="test_node",
graph_exec_id="test_exec",
node_exec_id="test_node_exec",
user_id="test_user",
):
outputs[output_name] = output_value
# Verify we had 2 LLM calls (initial + retry)
assert call_count == 2
# Check the final conversation output
final_conversation = outputs.get("conversations", [])
# The final conversation should NOT contain the validation error message
error_messages = [
msg
for msg in final_conversation
if msg.get("role") == "user"
and "parameter errors" in msg.get("content", "")
]
assert (
len(error_messages) == 0
), "Validation error leaked into final conversation"
# The final conversation should only have the successful response
assert final_conversation[-1]["content"] == "valid"

View File

@@ -1,131 +0,0 @@
import pytest
from backend.blocks.io import AgentTableInputBlock
from backend.util.test import execute_block_test
@pytest.mark.asyncio
async def test_table_input_block():
"""Test the AgentTableInputBlock with basic input/output."""
block = AgentTableInputBlock()
await execute_block_test(block)
@pytest.mark.asyncio
async def test_table_input_with_data():
"""Test AgentTableInputBlock with actual table data."""
block = AgentTableInputBlock()
input_data = block.Input(
name="test_table",
column_headers=["Name", "Age", "City"],
value=[
{"Name": "John", "Age": "30", "City": "New York"},
{"Name": "Jane", "Age": "25", "City": "London"},
{"Name": "Bob", "Age": "35", "City": "Paris"},
],
)
output_data = []
async for output_name, output_value in block.run(input_data):
output_data.append((output_name, output_value))
assert len(output_data) == 1
assert output_data[0][0] == "result"
result = output_data[0][1]
assert len(result) == 3
assert result[0]["Name"] == "John"
assert result[1]["Age"] == "25"
assert result[2]["City"] == "Paris"
@pytest.mark.asyncio
async def test_table_input_empty_data():
"""Test AgentTableInputBlock with empty data."""
block = AgentTableInputBlock()
input_data = block.Input(
name="empty_table", column_headers=["Col1", "Col2"], value=[]
)
output_data = []
async for output_name, output_value in block.run(input_data):
output_data.append((output_name, output_value))
assert len(output_data) == 1
assert output_data[0][0] == "result"
assert output_data[0][1] == []
@pytest.mark.asyncio
async def test_table_input_with_missing_columns():
"""Test AgentTableInputBlock passes through data with missing columns as-is."""
block = AgentTableInputBlock()
input_data = block.Input(
name="partial_table",
column_headers=["Name", "Age", "City"],
value=[
{"Name": "John", "Age": "30"}, # Missing City
{"Name": "Jane", "City": "London"}, # Missing Age
{"Age": "35", "City": "Paris"}, # Missing Name
],
)
output_data = []
async for output_name, output_value in block.run(input_data):
output_data.append((output_name, output_value))
result = output_data[0][1]
assert len(result) == 3
# Check data is passed through as-is
assert result[0] == {"Name": "John", "Age": "30"}
assert result[1] == {"Name": "Jane", "City": "London"}
assert result[2] == {"Age": "35", "City": "Paris"}
@pytest.mark.asyncio
async def test_table_input_none_value():
"""Test AgentTableInputBlock with None value returns empty list."""
block = AgentTableInputBlock()
input_data = block.Input(
name="none_table", column_headers=["Name", "Age"], value=None
)
output_data = []
async for output_name, output_value in block.run(input_data):
output_data.append((output_name, output_value))
assert len(output_data) == 1
assert output_data[0][0] == "result"
assert output_data[0][1] == []
@pytest.mark.asyncio
async def test_table_input_with_default_headers():
"""Test AgentTableInputBlock with default column headers."""
block = AgentTableInputBlock()
# Don't specify column_headers, should use defaults
input_data = block.Input(
name="default_headers_table",
value=[
{"Column 1": "A", "Column 2": "B", "Column 3": "C"},
{"Column 1": "D", "Column 2": "E", "Column 3": "F"},
],
)
output_data = []
async for output_name, output_value in block.run(input_data):
output_data.append((output_name, output_value))
assert len(output_data) == 1
assert output_data[0][0] == "result"
result = output_data[0][1]
assert len(result) == 2
assert result[0]["Column 1"] == "A"
assert result[1]["Column 3"] == "F"

View File

@@ -2,8 +2,6 @@ import re
from pathlib import Path
from typing import Any
import regex # Has built-in timeout support
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util import json, text
@@ -139,11 +137,6 @@ class ExtractTextInformationBlock(Block):
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
# Security fix: Add limits to prevent ReDoS and memory exhaustion
MAX_TEXT_LENGTH = 1_000_000 # 1MB character limit
MAX_MATCHES = 1000 # Maximum number of matches to prevent memory exhaustion
MAX_MATCH_LENGTH = 10_000 # Maximum length per match
flags = 0
if not input_data.case_sensitive:
flags = flags | re.IGNORECASE
@@ -155,85 +148,20 @@ class ExtractTextInformationBlock(Block):
else:
txt = json.dumps(input_data.text)
# Limit text size to prevent DoS
if len(txt) > MAX_TEXT_LENGTH:
txt = txt[:MAX_TEXT_LENGTH]
# Validate regex pattern to prevent dangerous patterns
dangerous_patterns = [
r".*\+.*\+", # Nested quantifiers
r".*\*.*\*", # Nested quantifiers
r"(?=.*\+)", # Lookahead with quantifier
r"(?=.*\*)", # Lookahead with quantifier
r"\(.+\)\+", # Group with nested quantifier
r"\(.+\)\*", # Group with nested quantifier
r"\([^)]+\+\)\+", # Nested quantifiers like (a+)+
r"\([^)]+\*\)\*", # Nested quantifiers like (a*)*
matches = [
match.group(input_data.group)
for match in re.finditer(input_data.pattern, txt, flags)
if input_data.group <= len(match.groups())
]
# Check if pattern is potentially dangerous
is_dangerous = any(
re.search(dangerous, input_data.pattern) for dangerous in dangerous_patterns
)
# Use regex module with timeout for dangerous patterns
# For safe patterns, use standard re module for compatibility
try:
matches = []
match_count = 0
if is_dangerous:
# Use regex module with timeout (5 seconds) for dangerous patterns
# The regex module supports timeout parameter in finditer
try:
for match in regex.finditer(
input_data.pattern, txt, flags=flags, timeout=5.0
):
if match_count >= MAX_MATCHES:
break
if input_data.group <= len(match.groups()):
match_text = match.group(input_data.group)
# Limit match length to prevent memory exhaustion
if len(match_text) > MAX_MATCH_LENGTH:
match_text = match_text[:MAX_MATCH_LENGTH]
matches.append(match_text)
match_count += 1
except regex.error as e:
# Timeout occurred or regex error
if "timeout" in str(e).lower():
# Timeout - return empty results
pass
else:
# Other regex error
raise
else:
# Use standard re module for non-dangerous patterns
for match in re.finditer(input_data.pattern, txt, flags):
if match_count >= MAX_MATCHES:
break
if input_data.group <= len(match.groups()):
match_text = match.group(input_data.group)
# Limit match length to prevent memory exhaustion
if len(match_text) > MAX_MATCH_LENGTH:
match_text = match_text[:MAX_MATCH_LENGTH]
matches.append(match_text)
match_count += 1
if not input_data.find_all:
matches = matches[:1]
for match in matches:
yield "positive", match
if not matches:
yield "negative", input_data.text
yield "matched_results", matches
yield "matched_count", len(matches)
except Exception:
# Return empty results on any regex error
if not input_data.find_all:
matches = matches[:1]
for match in matches:
yield "positive", match
if not matches:
yield "negative", input_data.text
yield "matched_results", []
yield "matched_count", 0
yield "matched_results", matches
yield "matched_count", len(matches)
class FillTextTemplateBlock(Block):
@@ -244,11 +172,6 @@ class FillTextTemplateBlock(Block):
format: str = SchemaField(
description="Template to format the text using `values`. Use Jinja2 syntax."
)
escape_html: bool = SchemaField(
default=False,
advanced=True,
description="Whether to escape special characters in the inserted values to be HTML-safe. Enable for HTML output, disable for plain text.",
)
class Output(BlockSchema):
output: str = SchemaField(description="Formatted text")
@@ -282,7 +205,6 @@ class FillTextTemplateBlock(Block):
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
formatter = text.TextFormatter(autoescape=input_data.escape_html)
yield "output", formatter.format_string(input_data.format, input_data.values)

View File

@@ -270,17 +270,13 @@ class GetCurrentDateBlock(Block):
test_output=[
(
"date",
lambda t: abs(
datetime.now().date() - datetime.strptime(t, "%Y-%m-%d").date()
)
<= timedelta(days=8), # 7 days difference + 1 day error margin.
lambda t: abs(datetime.now() - datetime.strptime(t, "%Y-%m-%d"))
< timedelta(days=8), # 7 days difference + 1 day error margin.
),
(
"date",
lambda t: abs(
datetime.now().date() - datetime.strptime(t, "%m/%d/%Y").date()
)
<= timedelta(days=8),
lambda t: abs(datetime.now() - datetime.strptime(t, "%m/%d/%Y"))
< timedelta(days=8),
# 7 days difference + 1 day error margin.
),
(
@@ -386,7 +382,7 @@ class GetCurrentDateAndTimeBlock(Block):
lambda t: abs(
datetime.now().date() - datetime.strptime(t, "%Y/%m/%d").date()
)
<= timedelta(days=1), # Date format only, no time component
< timedelta(days=1), # Date format only, no time component
),
(
"date_time",

View File

@@ -26,14 +26,6 @@ class XMLParserBlock(Block):
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
# Security fix: Add size limits to prevent XML bomb attacks
MAX_XML_SIZE = 10 * 1024 * 1024 # 10MB limit for XML input
if len(input_data.input_xml) > MAX_XML_SIZE:
raise ValueError(
f"XML too large: {len(input_data.input_xml)} bytes > {MAX_XML_SIZE} bytes"
)
try:
tokens = tokenize(input_data.input_xml)
parser = Parser(tokens)

View File

@@ -1,7 +1,6 @@
from urllib.parse import parse_qs, urlparse
from youtube_transcript_api._api import YouTubeTranscriptApi
from youtube_transcript_api._errors import NoTranscriptFound
from youtube_transcript_api._transcripts import FetchedTranscript
from youtube_transcript_api.formatters import TextFormatter
@@ -65,29 +64,7 @@ class TranscribeYoutubeVideoBlock(Block):
@staticmethod
def get_transcript(video_id: str) -> FetchedTranscript:
"""
Get transcript for a video, preferring English but falling back to any available language.
:param video_id: The YouTube video ID
:return: The fetched transcript
:raises: Any exception except NoTranscriptFound for requested languages
"""
api = YouTubeTranscriptApi()
try:
# Try to get English transcript first (default behavior)
return api.fetch(video_id=video_id)
except NoTranscriptFound:
# If English is not available, get the first available transcript
transcript_list = api.list(video_id)
# Try manually created transcripts first, then generated ones
available_transcripts = list(
transcript_list._manually_created_transcripts.values()
) + list(transcript_list._generated_transcripts.values())
if available_transcripts:
# Fetch the first available transcript
return available_transcripts[0].fetch()
# If no transcripts at all, re-raise the original error
raise
return YouTubeTranscriptApi().fetch(video_id=video_id)
@staticmethod
def format_transcript(transcript: FetchedTranscript) -> str:

View File

@@ -45,6 +45,9 @@ class MainApp(AppProcess):
app.main(silent=True)
def cleanup(self):
pass
@click.group()
def main():

View File

@@ -9,7 +9,6 @@ from prisma.models import APIKey as PrismaAPIKey
from prisma.types import APIKeyWhereUniqueInput
from pydantic import BaseModel, Field
from backend.data.includes import MAX_USER_API_KEYS_FETCH
from backend.util.exceptions import NotAuthorizedError, NotFoundError
logger = logging.getLogger(__name__)
@@ -179,13 +178,9 @@ async def revoke_api_key(key_id: str, user_id: str) -> APIKeyInfo:
return APIKeyInfo.from_db(updated_api_key)
async def list_user_api_keys(
user_id: str, limit: int = MAX_USER_API_KEYS_FETCH
) -> list[APIKeyInfo]:
async def list_user_api_keys(user_id: str) -> list[APIKeyInfo]:
api_keys = await PrismaAPIKey.prisma().find_many(
where={"userId": user_id},
order={"createdAt": "desc"},
take=limit,
where={"userId": user_id}, order={"createdAt": "desc"}
)
return [APIKeyInfo.from_db(key) for key in api_keys]

View File

@@ -20,6 +20,7 @@ from typing import (
import jsonref
import jsonschema
from autogpt_libs.utils.cache import cached
from prisma.models import AgentBlock
from prisma.types import AgentBlockCreateInput
from pydantic import BaseModel
@@ -27,7 +28,6 @@ from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.cache import cached
from backend.util.settings import Config
from .model import (
@@ -722,7 +722,7 @@ def get_block(block_id: str) -> Block[BlockSchema, BlockSchema] | None:
return cls() if cls else None
@cached(ttl_seconds=3600)
@cached()
def get_webhook_block_ids() -> Sequence[str]:
return [
id
@@ -731,7 +731,7 @@ def get_webhook_block_ids() -> Sequence[str]:
]
@cached(ttl_seconds=3600)
@cached()
def get_io_block_ids() -> Sequence[str]:
return [
id

View File

@@ -1,11 +1,7 @@
from typing import Type
from backend.blocks.ai_music_generator import AIMusicGeneratorBlock
from backend.blocks.ai_shortform_video_block import (
AIAdMakerVideoCreatorBlock,
AIScreenshotToVideoAdBlock,
AIShortformVideoCreatorBlock,
)
from backend.blocks.ai_shortform_video_block import AIShortformVideoCreatorBlock
from backend.blocks.apollo.organization import SearchOrganizationsBlock
from backend.blocks.apollo.people import SearchPeopleBlock
from backend.blocks.apollo.person import GetPersonDetailBlock
@@ -73,9 +69,9 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.CLAUDE_4_1_OPUS: 21,
LlmModel.CLAUDE_4_OPUS: 21,
LlmModel.CLAUDE_4_SONNET: 5,
LlmModel.CLAUDE_4_5_HAIKU: 4,
LlmModel.CLAUDE_4_5_SONNET: 9,
LlmModel.CLAUDE_3_7_SONNET: 5,
LlmModel.CLAUDE_3_5_SONNET: 4,
LlmModel.CLAUDE_3_5_HAIKU: 1, # $0.80 / $4.00
LlmModel.CLAUDE_3_HAIKU: 1,
LlmModel.AIML_API_QWEN2_5_72B: 1,
LlmModel.AIML_API_LLAMA3_1_70B: 1,
@@ -325,31 +321,7 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
],
AIShortformVideoCreatorBlock: [
BlockCost(
cost_amount=307,
cost_filter={
"credentials": {
"id": revid_credentials.id,
"provider": revid_credentials.provider,
"type": revid_credentials.type,
}
},
)
],
AIAdMakerVideoCreatorBlock: [
BlockCost(
cost_amount=714,
cost_filter={
"credentials": {
"id": revid_credentials.id,
"provider": revid_credentials.provider,
"type": revid_credentials.type,
}
},
)
],
AIScreenshotToVideoAdBlock: [
BlockCost(
cost_amount=612,
cost_amount=50,
cost_filter={
"credentials": {
"id": revid_credentials.id,

View File

@@ -5,6 +5,7 @@ from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, cast
import stripe
from prisma import Json
from prisma.enums import (
CreditRefundRequestStatus,
CreditTransactionType,
@@ -12,13 +13,16 @@ from prisma.enums import (
OnboardingStep,
)
from prisma.errors import UniqueViolationError
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
from prisma.types import CreditRefundRequestCreateInput, CreditTransactionWhereInput
from prisma.models import CreditRefundRequest, CreditTransaction, User
from prisma.types import (
CreditRefundRequestCreateInput,
CreditTransactionCreateInput,
CreditTransactionWhereInput,
)
from pydantic import BaseModel
from backend.data import db
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.db import query_raw_with_schema
from backend.data.includes import MAX_CREDIT_REFUND_REQUESTS_FETCH
from backend.data.model import (
AutoTopUpConfig,
RefundRequest,
@@ -31,8 +35,7 @@ from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.notifications.notifications import queue_notification_async
from backend.server.v2.admin.model import UserHistoryResponse
from backend.util.exceptions import InsufficientBalanceError
from backend.util.feature_flag import Flag, is_feature_enabled
from backend.util.json import SafeJson, dumps
from backend.util.json import SafeJson
from backend.util.models import Pagination
from backend.util.retry import func_retry
from backend.util.settings import Settings
@@ -45,10 +48,6 @@ stripe.api_key = settings.secrets.stripe_api_key
logger = logging.getLogger(__name__)
base_url = settings.config.frontend_base_url or settings.config.platform_base_url
# Constants for test compatibility
POSTGRES_INT_MAX = 2147483647
POSTGRES_INT_MIN = -2147483648
class UsageTransactionMetadata(BaseModel):
graph_exec_id: str | None = None
@@ -139,20 +138,14 @@ class UserCreditBase(ABC):
pass
@abstractmethod
async def onboarding_reward(
self, user_id: str, credits: int, step: OnboardingStep
) -> bool:
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
"""
Reward the user with credits for completing an onboarding step.
Won't reward if the user has already received credits for the step.
Args:
user_id (str): The user ID.
credits (int): The amount to reward.
step (OnboardingStep): The onboarding step.
Returns:
bool: True if rewarded, False if already rewarded.
"""
pass
@@ -242,12 +235,6 @@ class UserCreditBase(ABC):
"""
Returns the current balance of the user & the latest balance snapshot time.
"""
# Check UserBalance first for efficiency and consistency
user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
if user_balance:
return user_balance.balance, user_balance.updatedAt
# Fallback to transaction history computation if UserBalance doesn't exist
top_time = self.time_now()
snapshot = await CreditTransaction.prisma().find_first(
where={
@@ -262,86 +249,72 @@ class UserCreditBase(ABC):
snapshot_balance = snapshot.runningBalance or 0 if snapshot else 0
snapshot_time = snapshot.createdAt if snapshot else datetime_min
return snapshot_balance, snapshot_time
# Get transactions after the snapshot, this should not exist, but just in case.
transactions = await CreditTransaction.prisma().group_by(
by=["userId"],
sum={"amount": True},
max={"createdAt": True},
where={
"userId": user_id,
"createdAt": {
"gt": snapshot_time,
"lte": top_time,
},
"isActive": True,
},
)
transaction_balance = (
int(transactions[0].get("_sum", {}).get("amount", 0) + snapshot_balance)
if transactions
else snapshot_balance
)
transaction_time = (
datetime.fromisoformat(
str(transactions[0].get("_max", {}).get("createdAt", datetime_min))
)
if transactions
else snapshot_time
)
return transaction_balance, transaction_time
@func_retry
async def _enable_transaction(
self,
transaction_key: str,
user_id: str,
metadata: SafeJson,
metadata: Json,
new_transaction_key: str | None = None,
):
# First check if transaction exists and is inactive (safety check)
transaction = await CreditTransaction.prisma().find_first(
where={
"transactionKey": transaction_key,
"userId": user_id,
"isActive": False,
}
transaction = await CreditTransaction.prisma().find_first_or_raise(
where={"transactionKey": transaction_key, "userId": user_id}
)
if not transaction:
# Transaction doesn't exist or is already active, return early
return None
if transaction.isActive:
return
# Atomic operation to enable transaction and update user balance using UserBalance
result = await query_raw_with_schema(
"""
WITH user_balance_lock AS (
SELECT
$2::text as userId,
COALESCE(
(SELECT balance FROM {schema_prefix}"UserBalance" WHERE "userId" = $2 FOR UPDATE),
-- Fallback: compute balance from transaction history if UserBalance doesn't exist
(SELECT COALESCE(ct."runningBalance", 0)
FROM {schema_prefix}"CreditTransaction" ct
WHERE ct."userId" = $2
AND ct."isActive" = true
AND ct."runningBalance" IS NOT NULL
ORDER BY ct."createdAt" DESC
LIMIT 1),
0
) as balance
),
transaction_check AS (
SELECT * FROM {schema_prefix}"CreditTransaction"
WHERE "transactionKey" = $1 AND "userId" = $2 AND "isActive" = false
),
balance_update AS (
INSERT INTO {schema_prefix}"UserBalance" ("userId", "balance", "updatedAt")
SELECT
$2::text,
user_balance_lock.balance + transaction_check.amount,
CURRENT_TIMESTAMP
FROM user_balance_lock, transaction_check
ON CONFLICT ("userId") DO UPDATE SET
"balance" = EXCLUDED."balance",
"updatedAt" = EXCLUDED."updatedAt"
RETURNING "balance", "updatedAt"
),
transaction_update AS (
UPDATE {schema_prefix}"CreditTransaction"
SET "transactionKey" = COALESCE($4, $1),
"isActive" = true,
"runningBalance" = balance_update.balance,
"createdAt" = balance_update."updatedAt",
"metadata" = $3::jsonb
FROM balance_update, transaction_check
WHERE {schema_prefix}"CreditTransaction"."transactionKey" = transaction_check."transactionKey"
AND {schema_prefix}"CreditTransaction"."userId" = transaction_check."userId"
RETURNING {schema_prefix}"CreditTransaction"."runningBalance"
async with db.locked_transaction(f"usr_trx_{user_id}"):
transaction = await CreditTransaction.prisma().find_first_or_raise(
where={"transactionKey": transaction_key, "userId": user_id}
)
SELECT "runningBalance" as balance FROM transaction_update;
""",
transaction_key, # $1
user_id, # $2
dumps(metadata.data), # $3 - use pre-serialized JSON string for JSONB
new_transaction_key, # $4
)
if transaction.isActive:
return
if result:
# UserBalance is already updated by the CTE
return result[0]["balance"]
user_balance, _ = await self._get_credits(user_id)
await CreditTransaction.prisma().update(
where={
"creditTransactionIdentifier": {
"transactionKey": transaction_key,
"userId": user_id,
}
},
data={
"transactionKey": new_transaction_key or transaction_key,
"isActive": True,
"runningBalance": user_balance + transaction.amount,
"createdAt": self.time_now(),
"metadata": metadata,
},
)
async def _add_transaction(
self,
@@ -352,54 +325,12 @@ class UserCreditBase(ABC):
transaction_key: str | None = None,
ceiling_balance: int | None = None,
fail_insufficient_credits: bool = True,
metadata: SafeJson = SafeJson({}),
metadata: Json = SafeJson({}),
) -> tuple[int, str]:
"""
Add a new transaction for the user.
This is the only method that should be used to add a new transaction.
ATOMIC OPERATION DESIGN DECISION:
================================
This method uses PostgreSQL row-level locking (FOR UPDATE) for atomic credit operations.
After extensive analysis of concurrency patterns and correctness requirements, we determined
that the FOR UPDATE approach is necessary despite the latency overhead.
WHY FOR UPDATE LOCKING IS REQUIRED:
----------------------------------
1. **Data Consistency**: Credit operations must be ACID-compliant. The balance check,
calculation, and update must be atomic to prevent race conditions where:
- Multiple spend operations could exceed available balance
- Lost update problems could occur with concurrent top-ups
- Refunds could create negative balances incorrectly
2. **Serializability**: FOR UPDATE ensures operations are serialized at the database level,
guaranteeing that each transaction sees a consistent view of the balance before applying changes.
3. **Correctness Over Performance**: Financial operations require absolute correctness.
The ~10-50ms latency increase from row locking is acceptable for the guarantee that
no user will ever have an incorrect balance due to race conditions.
4. **PostgreSQL Optimization**: Modern PostgreSQL versions optimize row locks efficiently.
The performance cost is minimal compared to the complexity and risk of lock-free approaches.
ALTERNATIVES CONSIDERED AND REJECTED:
------------------------------------
- **Optimistic Concurrency**: Using version numbers or timestamps would require complex
retry logic and could still fail under high contention scenarios.
- **Application-Level Locking**: Redis locks or similar would add network overhead and
single points of failure while being less reliable than database locks.
- **Event Sourcing**: Would require complete architectural changes and eventual consistency
models that don't fit our real-time balance requirements.
PERFORMANCE CHARACTERISTICS:
---------------------------
- Single user operations: 10-50ms latency (acceptable for financial operations)
- Concurrent operations on same user: Serialized (prevents data corruption)
- Concurrent operations on different users: Fully parallel (no blocking)
This design prioritizes correctness and data integrity over raw performance,
which is the appropriate choice for a credit/payment system.
Args:
user_id (str): The user ID.
amount (int): The amount of credits to add.
@@ -413,142 +344,40 @@ class UserCreditBase(ABC):
Returns:
tuple[int, str]: The new balance & the transaction key.
"""
# Quick validation for ceiling balance to avoid unnecessary database operations
if ceiling_balance and amount > 0:
current_balance, _ = await self._get_credits(user_id)
if current_balance >= ceiling_balance:
async with db.locked_transaction(f"usr_trx_{user_id}"):
# Get latest balance snapshot
user_balance, _ = await self._get_credits(user_id)
if ceiling_balance and amount > 0 and user_balance >= ceiling_balance:
raise ValueError(
f"You already have enough balance of ${current_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
f"You already have enough balance of ${user_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
)
# Single unified atomic operation for all transaction types using UserBalance
try:
result = await query_raw_with_schema(
"""
WITH user_balance_lock AS (
SELECT
$1::text as userId,
-- CRITICAL: FOR UPDATE lock prevents concurrent modifications to the same user's balance
-- This ensures atomic read-modify-write operations and prevents race conditions
COALESCE(
(SELECT balance FROM {schema_prefix}"UserBalance" WHERE "userId" = $1 FOR UPDATE),
-- Fallback: compute balance from transaction history if UserBalance doesn't exist
(SELECT COALESCE(ct."runningBalance", 0)
FROM {schema_prefix}"CreditTransaction" ct
WHERE ct."userId" = $1
AND ct."isActive" = true
AND ct."runningBalance" IS NOT NULL
ORDER BY ct."createdAt" DESC
LIMIT 1),
0
) as balance
),
balance_update AS (
INSERT INTO {schema_prefix}"UserBalance" ("userId", "balance", "updatedAt")
SELECT
$1::text,
CASE
-- For inactive transactions: Don't update balance
WHEN $5::boolean = false THEN user_balance_lock.balance
-- For ceiling balance (amount > 0): Apply ceiling
WHEN $2 > 0 AND $7::int IS NOT NULL AND user_balance_lock.balance > $7::int - $2 THEN $7::int
-- For regular operations: Apply with overflow/underflow protection
WHEN user_balance_lock.balance + $2 > $6::int THEN $6::int
WHEN user_balance_lock.balance + $2 < $10::int THEN $10::int
ELSE user_balance_lock.balance + $2
END,
CURRENT_TIMESTAMP
FROM user_balance_lock
WHERE (
$5::boolean = false OR -- Allow inactive transactions
$2 >= 0 OR -- Allow positive amounts (top-ups, grants)
$8::boolean = false OR -- Allow when insufficient balance check is disabled
user_balance_lock.balance + $2 >= 0 -- Allow spending only when sufficient balance
if amount < 0 and user_balance + amount < 0:
if fail_insufficient_credits:
raise InsufficientBalanceError(
message=f"Insufficient balance of ${user_balance/100}, where this will cost ${abs(amount)/100}",
user_id=user_id,
balance=user_balance,
amount=amount,
)
ON CONFLICT ("userId") DO UPDATE SET
"balance" = EXCLUDED."balance",
"updatedAt" = EXCLUDED."updatedAt"
RETURNING "balance", "updatedAt"
),
transaction_insert AS (
INSERT INTO {schema_prefix}"CreditTransaction" (
"userId", "amount", "type", "runningBalance",
"metadata", "isActive", "createdAt", "transactionKey"
)
SELECT
$1::text,
$2::int,
$3::text::{schema_prefix}"CreditTransactionType",
CASE
-- For inactive transactions: Set runningBalance to original balance (don't apply the change yet)
WHEN $5::boolean = false THEN user_balance_lock.balance
ELSE COALESCE(balance_update.balance, user_balance_lock.balance)
END,
$4::jsonb,
$5::boolean,
COALESCE(balance_update."updatedAt", CURRENT_TIMESTAMP),
COALESCE($9, gen_random_uuid()::text)
FROM user_balance_lock
LEFT JOIN balance_update ON true
WHERE (
$5::boolean = false OR -- Allow inactive transactions
$2 >= 0 OR -- Allow positive amounts (top-ups, grants)
$8::boolean = false OR -- Allow when insufficient balance check is disabled
user_balance_lock.balance + $2 >= 0 -- Allow spending only when sufficient balance
)
RETURNING "runningBalance", "transactionKey"
)
SELECT "runningBalance" as balance, "transactionKey" FROM transaction_insert;
""",
user_id, # $1
amount, # $2
transaction_type.value, # $3
dumps(metadata.data), # $4 - use pre-serialized JSON string for JSONB
is_active, # $5
POSTGRES_INT_MAX, # $6 - overflow protection
ceiling_balance, # $7 - ceiling balance (nullable)
fail_insufficient_credits, # $8 - check balance for spending
transaction_key, # $9 - transaction key (nullable)
POSTGRES_INT_MIN, # $10 - underflow protection
)
except Exception as e:
# Convert raw SQL unique constraint violations to UniqueViolationError
# for consistent exception handling throughout the codebase
error_str = str(e).lower()
if (
"already exists" in error_str
or "duplicate key" in error_str
or "unique constraint" in error_str
):
# Extract table and constraint info for better error messages
# Re-raise as a UniqueViolationError but with proper format
# Create a minimal data structure that the error constructor expects
raise UniqueViolationError({"error": str(e), "user_facing_error": {}})
# For any other error, re-raise as-is
raise
if result:
new_balance, tx_key = result[0]["balance"], result[0]["transactionKey"]
# UserBalance is already updated by the CTE
return new_balance, tx_key
amount = min(-user_balance, 0)
# If no result, either user doesn't exist or insufficient balance
user = await User.prisma().find_unique(where={"id": user_id})
if not user:
raise ValueError(f"User {user_id} not found")
# Must be insufficient balance for spending operation
if amount < 0 and fail_insufficient_credits:
current_balance, _ = await self._get_credits(user_id)
raise InsufficientBalanceError(
message=f"Insufficient balance of ${current_balance/100}, where this will cost ${abs(amount)/100}",
user_id=user_id,
balance=current_balance,
amount=amount,
)
# Unexpected case
raise ValueError(f"Transaction failed for user {user_id}, amount {amount}")
# Create the transaction
transaction_data: CreditTransactionCreateInput = {
"userId": user_id,
"amount": amount,
"runningBalance": user_balance + amount,
"type": transaction_type,
"metadata": metadata,
"isActive": is_active,
"createdAt": self.time_now(),
}
if transaction_key:
transaction_data["transactionKey"] = transaction_key
tx = await CreditTransaction.prisma().create(data=transaction_data)
return user_balance + amount, tx.transactionKey
class UserCredit(UserCreditBase):
@@ -623,10 +452,9 @@ class UserCredit(UserCreditBase):
{"reason": f"Reward for completing {step.value} onboarding step."}
),
)
return True
except UniqueViolationError:
# User already received this reward
return False
# Already rewarded for this step
pass
async def top_up_refund(
self, user_id: str, transaction_key: str, metadata: dict[str, str]
@@ -815,7 +643,7 @@ class UserCredit(UserCreditBase):
):
# init metadata, without sharing it with the world
metadata = metadata or {}
if not metadata.get("reason"):
if not metadata["reason"]:
match top_up_type:
case TopUpType.MANUAL:
metadata["reason"] = {"reason": f"Top up credits for {user_id}"}
@@ -1077,9 +905,7 @@ class UserCredit(UserCreditBase):
),
)
async def get_refund_requests(
self, user_id: str, limit: int = MAX_CREDIT_REFUND_REQUESTS_FETCH
) -> list[RefundRequest]:
async def get_refund_requests(self, user_id: str) -> list[RefundRequest]:
return [
RefundRequest(
id=r.id,
@@ -1095,7 +921,6 @@ class UserCredit(UserCreditBase):
for r in await CreditRefundRequest.prisma().find_many(
where={"userId": user_id},
order={"createdAt": "desc"},
take=limit,
)
]
@@ -1145,8 +970,8 @@ class DisabledUserCredit(UserCreditBase):
async def top_up_credits(self, *args, **kwargs):
pass
async def onboarding_reward(self, *args, **kwargs) -> bool:
return True
async def onboarding_reward(self, *args, **kwargs):
pass
async def top_up_intent(self, *args, **kwargs) -> str:
return ""
@@ -1164,32 +989,15 @@ class DisabledUserCredit(UserCreditBase):
pass
async def get_user_credit_model(user_id: str) -> UserCreditBase:
"""
Get the credit model for a user, considering LaunchDarkly flags.
Args:
user_id (str): The user ID to check flags for.
Returns:
UserCreditBase: The appropriate credit model for the user
"""
def get_user_credit_model() -> UserCreditBase:
if not settings.config.enable_credit:
return DisabledUserCredit()
# Check LaunchDarkly flag for payment pilot users
# Default to False (beta monthly credit behavior) to maintain current behavior
is_payment_enabled = await is_feature_enabled(
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
)
if is_payment_enabled:
# Payment enabled users get UserCredit (no monthly refills, enable payments)
return UserCredit()
else:
# Default behavior: users get beta monthly credits
if settings.config.enable_beta_monthly_credit:
return BetaUserCredit(settings.config.num_user_credits_refill)
return UserCredit()
def get_block_costs() -> dict[str, list["BlockCost"]]:
return {block().id: costs for block, costs in BLOCK_COSTS.items()}
@@ -1278,8 +1086,7 @@ async def admin_get_user_history(
)
reason = metadata.get("reason", "No reason provided")
user_credit_model = await get_user_credit_model(tx.userId)
balance, _ = await user_credit_model._get_credits(tx.userId)
balance, last_update = await get_user_credit_model()._get_credits(tx.userId)
history.append(
UserTransaction(

View File

@@ -1,172 +0,0 @@
"""
Test ceiling balance functionality to ensure auto top-up limits work correctly.
This test was added to cover a previously untested code path that could lead to
incorrect balance capping behavior.
"""
from uuid import uuid4
import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from backend.data.credit import UserCredit
from backend.util.json import SafeJson
from backend.util.test import SpinTestServer
async def create_test_user(user_id: str) -> None:
"""Create a test user for ceiling tests."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
)
except UniqueViolationError:
# User already exists, continue
pass
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
)
async def cleanup_test_user(user_id: str) -> None:
"""Clean up test user and their transactions."""
try:
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await User.prisma().delete_many(where={"id": user_id})
except Exception as e:
# Log cleanup failures but don't fail the test
print(f"Warning: Failed to cleanup test user {user_id}: {e}")
@pytest.mark.asyncio(loop_scope="session")
async def test_ceiling_balance_rejects_when_above_threshold(server: SpinTestServer):
"""Test that ceiling balance correctly rejects top-ups when balance is above threshold."""
credit_system = UserCredit()
user_id = f"ceiling-test-{uuid4()}"
await create_test_user(user_id)
try:
# Give user balance of 1000 ($10) using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=1000,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "initial_balance"}),
)
current_balance = await credit_system.get_credits(user_id)
assert current_balance == 1000
# Try to add 200 more with ceiling of 800 (should reject since 1000 > 800)
with pytest.raises(ValueError, match="You already have enough balance"):
await credit_system._add_transaction(
user_id=user_id,
amount=200,
transaction_type=CreditTransactionType.TOP_UP,
ceiling_balance=800, # Ceiling lower than current balance
)
# Balance should remain unchanged
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 1000, f"Balance should remain 1000, got {final_balance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_ceiling_balance_clamps_when_would_exceed(server: SpinTestServer):
"""Test that ceiling balance correctly clamps amounts that would exceed the ceiling."""
credit_system = UserCredit()
user_id = f"ceiling-clamp-test-{uuid4()}"
await create_test_user(user_id)
try:
# Give user balance of 500 ($5) using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=500,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "initial_balance"}),
)
# Add 800 more with ceiling of 1000 (should clamp to 1000, not reach 1300)
final_balance, _ = await credit_system._add_transaction(
user_id=user_id,
amount=800,
transaction_type=CreditTransactionType.TOP_UP,
ceiling_balance=1000, # Ceiling should clamp 500 + 800 = 1300 to 1000
)
# Balance should be clamped to ceiling
assert (
final_balance == 1000
), f"Balance should be clamped to 1000, got {final_balance}"
# Verify with get_credits too
stored_balance = await credit_system.get_credits(user_id)
assert (
stored_balance == 1000
), f"Stored balance should be 1000, got {stored_balance}"
# Verify transaction shows the clamped amount
transactions = await CreditTransaction.prisma().find_many(
where={"userId": user_id, "type": CreditTransactionType.TOP_UP},
order={"createdAt": "desc"},
)
# Should have 2 transactions: 500 + (500 to reach ceiling of 1000)
assert len(transactions) == 2
# The second transaction should show it only added 500, not 800
second_tx = transactions[0] # Most recent
assert second_tx.runningBalance == 1000
# The actual amount recorded could be 800 (what was requested) but balance was clamped
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_ceiling_balance_allows_when_under_threshold(server: SpinTestServer):
"""Test that ceiling balance allows top-ups when balance is under threshold."""
credit_system = UserCredit()
user_id = f"ceiling-under-test-{uuid4()}"
await create_test_user(user_id)
try:
# Give user balance of 300 ($3) using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=300,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "initial_balance"}),
)
# Add 200 more with ceiling of 1000 (should succeed: 300 + 200 = 500 < 1000)
final_balance, _ = await credit_system._add_transaction(
user_id=user_id,
amount=200,
transaction_type=CreditTransactionType.TOP_UP,
ceiling_balance=1000,
)
# Balance should be exactly 500
assert final_balance == 500, f"Balance should be 500, got {final_balance}"
# Verify with get_credits too
stored_balance = await credit_system.get_credits(user_id)
assert (
stored_balance == 500
), f"Stored balance should be 500, got {stored_balance}"
finally:
await cleanup_test_user(user_id)

View File

@@ -1,737 +0,0 @@
"""
Concurrency and atomicity tests for the credit system.
These tests ensure the credit system handles high-concurrency scenarios correctly
without race conditions, deadlocks, or inconsistent state.
"""
import asyncio
import random
from uuid import uuid4
import prisma.enums
import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit
from backend.util.exceptions import InsufficientBalanceError
from backend.util.json import SafeJson
from backend.util.test import SpinTestServer
# Test with both UserCredit and BetaUserCredit if needed
credit_system = UserCredit()
async def create_test_user(user_id: str) -> None:
"""Create a test user with initial balance."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
)
except UniqueViolationError:
# User already exists, continue
pass
# Ensure UserBalance record exists
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
)
async def cleanup_test_user(user_id: str) -> None:
"""Clean up test user and their transactions."""
try:
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await UserBalance.prisma().delete_many(where={"userId": user_id})
await User.prisma().delete_many(where={"id": user_id})
except Exception as e:
# Log cleanup failures but don't fail the test
print(f"Warning: Failed to cleanup test user {user_id}: {e}")
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_spends_same_user(server: SpinTestServer):
"""Test multiple concurrent spends from the same user don't cause race conditions."""
user_id = f"concurrent-test-{uuid4()}"
await create_test_user(user_id)
try:
# Give user initial balance using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=1000,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "initial_balance"}),
)
# Try to spend 10 x $1 concurrently
async def spend_one_dollar(idx: int):
try:
return await credit_system.spend_credits(
user_id,
100, # $1
UsageTransactionMetadata(
graph_exec_id=f"concurrent-{idx}",
reason=f"Concurrent spend {idx}",
),
)
except InsufficientBalanceError:
return None
# Run 10 concurrent spends
results = await asyncio.gather(
*[spend_one_dollar(i) for i in range(10)], return_exceptions=True
)
# Count successful spends
successful = [
r for r in results if r is not None and not isinstance(r, Exception)
]
failed = [r for r in results if isinstance(r, InsufficientBalanceError)]
# All 10 should succeed since we have exactly $10
assert len(successful) == 10, f"Expected 10 successful, got {len(successful)}"
assert len(failed) == 0, f"Expected 0 failures, got {len(failed)}"
# Final balance should be exactly 0
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 0, f"Expected balance 0, got {final_balance}"
# Verify transaction history is consistent
transactions = await CreditTransaction.prisma().find_many(
where={"userId": user_id, "type": prisma.enums.CreditTransactionType.USAGE}
)
assert (
len(transactions) == 10
), f"Expected 10 transactions, got {len(transactions)}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_spends_insufficient_balance(server: SpinTestServer):
"""Test that concurrent spends correctly enforce balance limits."""
user_id = f"insufficient-test-{uuid4()}"
await create_test_user(user_id)
try:
# Give user limited balance using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=500,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "limited_balance"}),
)
# Try to spend 10 x $1 concurrently (but only have $5)
async def spend_one_dollar(idx: int):
try:
return await credit_system.spend_credits(
user_id,
100, # $1
UsageTransactionMetadata(
graph_exec_id=f"insufficient-{idx}",
reason=f"Insufficient spend {idx}",
),
)
except InsufficientBalanceError:
return "FAILED"
# Run 10 concurrent spends
results = await asyncio.gather(
*[spend_one_dollar(i) for i in range(10)], return_exceptions=True
)
# Count successful vs failed
successful = [
r
for r in results
if r not in ["FAILED", None] and not isinstance(r, Exception)
]
failed = [r for r in results if r == "FAILED"]
# Exactly 5 should succeed, 5 should fail
assert len(successful) == 5, f"Expected 5 successful, got {len(successful)}"
assert len(failed) == 5, f"Expected 5 failures, got {len(failed)}"
# Final balance should be exactly 0
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 0, f"Expected balance 0, got {final_balance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_mixed_operations(server: SpinTestServer):
"""Test concurrent mix of spends, top-ups, and balance checks."""
user_id = f"mixed-test-{uuid4()}"
await create_test_user(user_id)
try:
# Initial balance using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=1000,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "initial_balance"}),
)
# Mix of operations
async def mixed_operations():
operations = []
# 5 spends of $1 each
for i in range(5):
operations.append(
credit_system.spend_credits(
user_id,
100,
UsageTransactionMetadata(reason=f"Mixed spend {i}"),
)
)
# 3 top-ups of $2 each using internal method
for i in range(3):
operations.append(
credit_system._add_transaction(
user_id=user_id,
amount=200,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": f"concurrent_topup_{i}"}),
)
)
# 10 balance checks
for i in range(10):
operations.append(credit_system.get_credits(user_id))
return await asyncio.gather(*operations, return_exceptions=True)
results = await mixed_operations()
# Check no exceptions occurred
exceptions = [
r
for r in results
if isinstance(r, Exception) and not isinstance(r, InsufficientBalanceError)
]
assert len(exceptions) == 0, f"Unexpected exceptions: {exceptions}"
# Final balance should be: 1000 - 500 + 600 = 1100
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 1100, f"Expected balance 1100, got {final_balance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_race_condition_exact_balance(server: SpinTestServer):
"""Test spending exact balance amount concurrently doesn't go negative."""
user_id = f"exact-balance-{uuid4()}"
await create_test_user(user_id)
try:
# Give exact amount using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=100,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "exact_amount"}),
)
# Try to spend $1 twice concurrently
async def spend_exact():
try:
return await credit_system.spend_credits(
user_id, 100, UsageTransactionMetadata(reason="Exact spend")
)
except InsufficientBalanceError:
return "FAILED"
# Both try to spend the full balance
result1, result2 = await asyncio.gather(spend_exact(), spend_exact())
# Exactly one should succeed
results = [result1, result2]
successful = [
r for r in results if r != "FAILED" and not isinstance(r, Exception)
]
failed = [r for r in results if r == "FAILED"]
assert len(successful) == 1, f"Expected 1 success, got {len(successful)}"
assert len(failed) == 1, f"Expected 1 failure, got {len(failed)}"
# Balance should be exactly 0, never negative
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 0, f"Expected balance 0, got {final_balance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_onboarding_reward_idempotency(server: SpinTestServer):
"""Test that onboarding rewards are idempotent (can't be claimed twice)."""
user_id = f"onboarding-test-{uuid4()}"
await create_test_user(user_id)
try:
# Use WELCOME step which is defined in the OnboardingStep enum
# Try to claim same reward multiple times concurrently
async def claim_reward():
try:
result = await credit_system.onboarding_reward(
user_id, 500, prisma.enums.OnboardingStep.WELCOME
)
return "SUCCESS" if result else "DUPLICATE"
except Exception as e:
print(f"Claim reward failed: {e}")
return "FAILED"
# Try 5 concurrent claims of the same reward
results = await asyncio.gather(*[claim_reward() for _ in range(5)])
# Count results
success_count = results.count("SUCCESS")
failed_count = results.count("FAILED")
# At least one should succeed, others should be duplicates
assert success_count >= 1, f"At least one claim should succeed, got {results}"
assert failed_count == 0, f"No claims should fail, got {results}"
# Check balance - should only have 500, not 2500
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 500, f"Expected balance 500, got {final_balance}"
# Check only one transaction exists
transactions = await CreditTransaction.prisma().find_many(
where={
"userId": user_id,
"type": prisma.enums.CreditTransactionType.GRANT,
"transactionKey": f"REWARD-{user_id}-WELCOME",
}
)
assert (
len(transactions) == 1
), f"Expected 1 reward transaction, got {len(transactions)}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_integer_overflow_protection(server: SpinTestServer):
"""Test that integer overflow is prevented by clamping to POSTGRES_INT_MAX."""
user_id = f"overflow-test-{uuid4()}"
await create_test_user(user_id)
try:
# Try to add amount that would overflow
max_int = POSTGRES_INT_MAX
# First, set balance near max
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": max_int - 100},
"update": {"balance": max_int - 100},
},
)
# Try to add more than possible - should clamp to POSTGRES_INT_MAX
await credit_system._add_transaction(
user_id=user_id,
amount=200,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "overflow_protection"}),
)
# Balance should be clamped to max_int, not overflowed
final_balance = await credit_system.get_credits(user_id)
assert (
final_balance == max_int
), f"Balance should be clamped to {max_int}, got {final_balance}"
# Verify transaction was created with clamped amount
transactions = await CreditTransaction.prisma().find_many(
where={
"userId": user_id,
"type": prisma.enums.CreditTransactionType.TOP_UP,
},
order={"createdAt": "desc"},
)
assert len(transactions) > 0, "Transaction should be created"
assert (
transactions[0].runningBalance == max_int
), "Transaction should show clamped balance"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_high_concurrency_stress(server: SpinTestServer):
"""Stress test with many concurrent operations."""
user_id = f"stress-test-{uuid4()}"
await create_test_user(user_id)
try:
# Initial balance using internal method (bypasses Stripe)
initial_balance = 10000 # $100
await credit_system._add_transaction(
user_id=user_id,
amount=initial_balance,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "stress_test_balance"}),
)
# Run many concurrent operations
async def random_operation(idx: int):
operation = random.choice(["spend", "check"])
if operation == "spend":
amount = random.randint(1, 50) # $0.01 to $0.50
try:
return (
"spend",
amount,
await credit_system.spend_credits(
user_id,
amount,
UsageTransactionMetadata(reason=f"Stress {idx}"),
),
)
except InsufficientBalanceError:
return ("spend_failed", amount, None)
else:
balance = await credit_system.get_credits(user_id)
return ("check", 0, balance)
# Run 100 concurrent operations
results = await asyncio.gather(
*[random_operation(i) for i in range(100)], return_exceptions=True
)
# Calculate expected final balance
total_spent = sum(
r[1]
for r in results
if not isinstance(r, Exception) and isinstance(r, tuple) and r[0] == "spend"
)
expected_balance = initial_balance - total_spent
# Verify final balance
final_balance = await credit_system.get_credits(user_id)
assert (
final_balance == expected_balance
), f"Expected {expected_balance}, got {final_balance}"
assert final_balance >= 0, "Balance went negative!"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestServer):
"""Test multiple concurrent spends when there's sufficient balance for all."""
user_id = f"multi-spend-test-{uuid4()}"
await create_test_user(user_id)
try:
# Give user 150 balance ($1.50) using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=150,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "sufficient_balance"}),
)
# Track individual timing to see serialization
timings = {}
async def spend_with_detailed_timing(amount: int, label: str):
start = asyncio.get_event_loop().time()
try:
await credit_system.spend_credits(
user_id,
amount,
UsageTransactionMetadata(
graph_exec_id=f"concurrent-{label}",
reason=f"Concurrent spend {label}",
),
)
end = asyncio.get_event_loop().time()
timings[label] = {"start": start, "end": end, "duration": end - start}
return f"{label}-SUCCESS"
except Exception as e:
end = asyncio.get_event_loop().time()
timings[label] = {
"start": start,
"end": end,
"duration": end - start,
"error": str(e),
}
return f"{label}-FAILED: {e}"
# Run concurrent spends: 10, 20, 30 (total 60, well under 150)
overall_start = asyncio.get_event_loop().time()
results = await asyncio.gather(
spend_with_detailed_timing(10, "spend-10"),
spend_with_detailed_timing(20, "spend-20"),
spend_with_detailed_timing(30, "spend-30"),
return_exceptions=True,
)
overall_end = asyncio.get_event_loop().time()
print(f"Results: {results}")
print(f"Overall duration: {overall_end - overall_start:.4f}s")
# Analyze timing to detect serialization vs true concurrency
print("\nTiming analysis:")
for label, timing in timings.items():
print(
f" {label}: started at {timing['start']:.4f}, ended at {timing['end']:.4f}, duration {timing['duration']:.4f}s"
)
# Check if operations overlapped (true concurrency) or were serialized
sorted_timings = sorted(timings.items(), key=lambda x: x[1]["start"])
print("\nExecution order by start time:")
for i, (label, timing) in enumerate(sorted_timings):
print(f" {i+1}. {label}: {timing['start']:.4f} -> {timing['end']:.4f}")
# Check for overlap (true concurrency) vs serialization
overlaps = []
for i in range(len(sorted_timings) - 1):
current = sorted_timings[i]
next_op = sorted_timings[i + 1]
if current[1]["end"] > next_op[1]["start"]:
overlaps.append(f"{current[0]} overlaps with {next_op[0]}")
if overlaps:
print(f"✅ TRUE CONCURRENCY detected: {overlaps}")
else:
print("🔒 SERIALIZATION detected: No overlapping execution times")
# Check final balance
final_balance = await credit_system.get_credits(user_id)
print(f"Final balance: {final_balance}")
# Count successes/failures
successful = [r for r in results if "SUCCESS" in str(r)]
failed = [r for r in results if "FAILED" in str(r)]
print(f"Successful: {len(successful)}, Failed: {len(failed)}")
# All should succeed since 150 - (10 + 20 + 30) = 90 > 0
assert (
len(successful) == 3
), f"Expected all 3 to succeed, got {len(successful)} successes: {results}"
assert final_balance == 90, f"Expected balance 90, got {final_balance}"
# Check transaction timestamps to confirm database-level serialization
transactions = await CreditTransaction.prisma().find_many(
where={"userId": user_id, "type": prisma.enums.CreditTransactionType.USAGE},
order={"createdAt": "asc"},
)
print("\nDatabase transaction order (by createdAt):")
for i, tx in enumerate(transactions):
print(
f" {i+1}. Amount {tx.amount}, Running balance: {tx.runningBalance}, Created: {tx.createdAt}"
)
# Verify running balances are chronologically consistent (ordered by createdAt)
actual_balances = [
tx.runningBalance for tx in transactions if tx.runningBalance is not None
]
print(f"Running balances: {actual_balances}")
# The balances should be valid intermediate states regardless of execution order
# Starting balance: 150, spending 10+20+30=60, so final should be 90
# The intermediate balances depend on execution order but should all be valid
expected_possible_balances = {
# If order is 10, 20, 30: [140, 120, 90]
# If order is 10, 30, 20: [140, 110, 90]
# If order is 20, 10, 30: [130, 120, 90]
# If order is 20, 30, 10: [130, 100, 90]
# If order is 30, 10, 20: [120, 110, 90]
# If order is 30, 20, 10: [120, 100, 90]
90,
100,
110,
120,
130,
140, # All possible intermediate balances
}
# Verify all balances are valid intermediate states
for balance in actual_balances:
assert (
balance in expected_possible_balances
), f"Invalid balance {balance}, expected one of {expected_possible_balances}"
# Final balance should always be 90 (150 - 60)
assert (
min(actual_balances) == 90
), f"Final balance should be 90, got {min(actual_balances)}"
# The final transaction should always have balance 90
# The other transactions should have valid intermediate balances
assert (
90 in actual_balances
), f"Final balance 90 should be in actual_balances: {actual_balances}"
# All balances should be >= 90 (the final state)
assert all(
balance >= 90 for balance in actual_balances
), f"All balances should be >= 90, got {actual_balances}"
# CRITICAL: Transactions are atomic but can complete in any order
# What matters is that all running balances are valid intermediate states
# Each balance should be between 90 (final) and 140 (after first transaction)
for balance in actual_balances:
assert (
90 <= balance <= 140
), f"Balance {balance} is outside valid range [90, 140]"
# Final balance (minimum) should always be 90
assert (
min(actual_balances) == 90
), f"Final balance should be 90, got {min(actual_balances)}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_prove_database_locking_behavior(server: SpinTestServer):
"""Definitively prove whether database locking causes waiting vs failures."""
user_id = f"locking-test-{uuid4()}"
await create_test_user(user_id)
try:
# Set balance to exact amount that can handle all spends using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=60, # Exactly 10+20+30
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "exact_amount_test"}),
)
async def spend_with_precise_timing(amount: int, label: str):
request_start = asyncio.get_event_loop().time()
db_operation_start = asyncio.get_event_loop().time()
try:
# Add a small delay to increase chance of true concurrency
await asyncio.sleep(0.001)
db_operation_start = asyncio.get_event_loop().time()
await credit_system.spend_credits(
user_id,
amount,
UsageTransactionMetadata(
graph_exec_id=f"locking-{label}",
reason=f"Locking test {label}",
),
)
db_operation_end = asyncio.get_event_loop().time()
return {
"label": label,
"status": "SUCCESS",
"request_start": request_start,
"db_start": db_operation_start,
"db_end": db_operation_end,
"db_duration": db_operation_end - db_operation_start,
}
except Exception as e:
db_operation_end = asyncio.get_event_loop().time()
return {
"label": label,
"status": "FAILED",
"error": str(e),
"request_start": request_start,
"db_start": db_operation_start,
"db_end": db_operation_end,
"db_duration": db_operation_end - db_operation_start,
}
# Launch all requests simultaneously
results = await asyncio.gather(
spend_with_precise_timing(10, "A"),
spend_with_precise_timing(20, "B"),
spend_with_precise_timing(30, "C"),
return_exceptions=True,
)
print("\n🔍 LOCKING BEHAVIOR ANALYSIS:")
print("=" * 50)
successful = [
r for r in results if isinstance(r, dict) and r.get("status") == "SUCCESS"
]
failed = [
r for r in results if isinstance(r, dict) and r.get("status") == "FAILED"
]
print(f"✅ Successful operations: {len(successful)}")
print(f"❌ Failed operations: {len(failed)}")
if len(failed) > 0:
print(
"\n🚫 CONCURRENT FAILURES - Some requests failed due to insufficient balance:"
)
for result in failed:
if isinstance(result, dict):
print(
f" {result['label']}: {result.get('error', 'Unknown error')}"
)
if len(successful) == 3:
print(
"\n🔒 SERIALIZATION CONFIRMED - All requests succeeded, indicating they were queued:"
)
# Sort by actual execution time to see order
dict_results = [r for r in results if isinstance(r, dict)]
sorted_results = sorted(dict_results, key=lambda x: x["db_start"])
for i, result in enumerate(sorted_results):
print(
f" {i+1}. {result['label']}: DB operation took {result['db_duration']:.4f}s"
)
# Check if any operations overlapped at the database level
print("\n⏱️ Database operation timeline:")
for result in sorted_results:
print(
f" {result['label']}: {result['db_start']:.4f} -> {result['db_end']:.4f}"
)
# Verify final state
final_balance = await credit_system.get_credits(user_id)
print(f"\n💰 Final balance: {final_balance}")
if len(successful) == 3:
assert (
final_balance == 0
), f"If all succeeded, balance should be 0, got {final_balance}"
print(
"✅ CONCLUSION: Database row locking causes requests to WAIT and execute serially"
)
else:
print(
"❌ CONCLUSION: Some requests failed, indicating different concurrency behavior"
)
finally:
await cleanup_test_user(user_id)

View File

@@ -1,277 +0,0 @@
"""
Integration tests for credit system to catch SQL enum casting issues.
These tests run actual database operations to ensure SQL queries work correctly,
which would have caught the CreditTransactionType enum casting bug.
"""
import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, User, UserBalance
from backend.data.credit import (
AutoTopUpConfig,
BetaUserCredit,
UsageTransactionMetadata,
get_auto_top_up,
set_auto_top_up,
)
from backend.util.json import SafeJson
@pytest.fixture
async def cleanup_test_user():
"""Clean up test user data before and after tests."""
import uuid
user_id = str(uuid.uuid4()) # Use unique user ID for each test
# Create the user first
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"topUpConfig": SafeJson({}),
"timezone": "UTC",
}
)
except Exception:
# User might already exist, that's fine
pass
yield user_id
# Cleanup after test
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await UserBalance.prisma().delete_many(where={"userId": user_id})
# Clear auto-top-up config before deleting user
await User.prisma().update(
where={"id": user_id}, data={"topUpConfig": SafeJson({})}
)
await User.prisma().delete(where={"id": user_id})
@pytest.mark.asyncio(loop_scope="session")
async def test_credit_transaction_enum_casting_integration(cleanup_test_user):
"""
Integration test to verify CreditTransactionType enum casting works in SQL queries.
This test would have caught the enum casting bug where PostgreSQL expected
platform."CreditTransactionType" but got "CreditTransactionType".
"""
user_id = cleanup_test_user
credit_system = BetaUserCredit(1000)
# Test each transaction type to ensure enum casting works
test_cases = [
(CreditTransactionType.TOP_UP, 100, "Test top-up"),
(CreditTransactionType.USAGE, -50, "Test usage"),
(CreditTransactionType.GRANT, 200, "Test grant"),
(CreditTransactionType.REFUND, -25, "Test refund"),
(CreditTransactionType.CARD_CHECK, 0, "Test card check"),
]
for transaction_type, amount, reason in test_cases:
metadata = SafeJson({"reason": reason, "test": "enum_casting"})
# This call would fail with enum casting error before the fix
balance, tx_key = await credit_system._add_transaction(
user_id=user_id,
amount=amount,
transaction_type=transaction_type,
metadata=metadata,
is_active=True,
)
# Verify transaction was created with correct type
transaction = await CreditTransaction.prisma().find_first(
where={"userId": user_id, "transactionKey": tx_key}
)
assert transaction is not None
assert transaction.type == transaction_type
assert transaction.amount == amount
assert transaction.metadata is not None
# Verify metadata content
assert transaction.metadata["reason"] == reason
assert transaction.metadata["test"] == "enum_casting"
@pytest.mark.asyncio(loop_scope="session")
async def test_auto_top_up_integration(cleanup_test_user, monkeypatch):
"""
Integration test for auto-top-up functionality that triggers enum casting.
This tests the complete auto-top-up flow which involves SQL queries with
CreditTransactionType enums, ensuring enum casting works end-to-end.
"""
# Enable credits for this test
from backend.data.credit import settings
monkeypatch.setattr(settings.config, "enable_credit", True)
monkeypatch.setattr(settings.config, "enable_beta_monthly_credit", True)
monkeypatch.setattr(settings.config, "num_user_credits_refill", 1000)
user_id = cleanup_test_user
credit_system = BetaUserCredit(1000)
# First add some initial credits so we can test the configuration and subsequent behavior
balance, _ = await credit_system._add_transaction(
user_id=user_id,
amount=50, # Below threshold that we'll set
transaction_type=CreditTransactionType.GRANT,
metadata=SafeJson({"reason": "Initial credits before auto top-up config"}),
)
assert balance == 50
# Configure auto top-up with threshold above current balance
config = AutoTopUpConfig(threshold=100, amount=500)
await set_auto_top_up(user_id, config)
# Verify configuration was saved but no immediate top-up occurred
current_balance = await credit_system.get_credits(user_id)
assert current_balance == 50 # Balance should be unchanged
# Simulate spending credits that would trigger auto top-up
# This involves multiple SQL operations with enum casting
try:
metadata = UsageTransactionMetadata(reason="Test spend to trigger auto top-up")
await credit_system.spend_credits(user_id=user_id, cost=10, metadata=metadata)
# The auto top-up mechanism should have been triggered
# Verify the transaction types were handled correctly
transactions = await CreditTransaction.prisma().find_many(
where={"userId": user_id}, order={"createdAt": "desc"}
)
# Should have at least: GRANT (initial), USAGE (spend), and TOP_UP (auto top-up)
assert len(transactions) >= 3
# Verify different transaction types exist and enum casting worked
transaction_types = {t.type for t in transactions}
assert CreditTransactionType.GRANT in transaction_types
assert CreditTransactionType.USAGE in transaction_types
assert (
CreditTransactionType.TOP_UP in transaction_types
) # Auto top-up should have triggered
except Exception as e:
# If this fails with enum casting error, the test successfully caught the bug
if "CreditTransactionType" in str(e) and (
"cast" in str(e).lower() or "type" in str(e).lower()
):
pytest.fail(f"Enum casting error detected: {e}")
else:
# Re-raise other unexpected errors
raise
@pytest.mark.asyncio(loop_scope="session")
async def test_enable_transaction_enum_casting_integration(cleanup_test_user):
"""
Integration test for _enable_transaction with enum casting.
Tests the scenario where inactive transactions are enabled, which also
involves SQL queries with CreditTransactionType enum casting.
"""
user_id = cleanup_test_user
credit_system = BetaUserCredit(1000)
# Create an inactive transaction
balance, tx_key = await credit_system._add_transaction(
user_id=user_id,
amount=100,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"reason": "Inactive transaction test"}),
is_active=False, # Create as inactive
)
# Balance should be 0 since transaction is inactive
assert balance == 0
# Enable the transaction with new metadata
enable_metadata = SafeJson(
{
"payment_method": "test_payment",
"activation_reason": "Integration test activation",
}
)
# This would fail with enum casting error before the fix
final_balance = await credit_system._enable_transaction(
transaction_key=tx_key,
user_id=user_id,
metadata=enable_metadata,
)
# Now balance should reflect the activated transaction
assert final_balance == 100
# Verify transaction was properly enabled with correct enum type
transaction = await CreditTransaction.prisma().find_first(
where={"userId": user_id, "transactionKey": tx_key}
)
assert transaction is not None
assert transaction.isActive is True
assert transaction.type == CreditTransactionType.TOP_UP
assert transaction.runningBalance == 100
# Verify metadata was updated
assert transaction.metadata is not None
assert transaction.metadata["payment_method"] == "test_payment"
assert transaction.metadata["activation_reason"] == "Integration test activation"
@pytest.mark.asyncio(loop_scope="session")
async def test_auto_top_up_configuration_storage(cleanup_test_user, monkeypatch):
"""
Test that auto-top-up configuration is properly stored and retrieved.
The immediate top-up logic is handled by the API routes, not the core
set_auto_top_up function. This test verifies the configuration is correctly
saved and can be retrieved.
"""
# Enable credits for this test
from backend.data.credit import settings
monkeypatch.setattr(settings.config, "enable_credit", True)
monkeypatch.setattr(settings.config, "enable_beta_monthly_credit", True)
monkeypatch.setattr(settings.config, "num_user_credits_refill", 1000)
user_id = cleanup_test_user
credit_system = BetaUserCredit(1000)
# Set initial balance
balance, _ = await credit_system._add_transaction(
user_id=user_id,
amount=50,
transaction_type=CreditTransactionType.GRANT,
metadata=SafeJson({"reason": "Initial balance for config test"}),
)
assert balance == 50
# Configure auto top-up
config = AutoTopUpConfig(threshold=100, amount=200)
await set_auto_top_up(user_id, config)
# Verify the configuration was saved
retrieved_config = await get_auto_top_up(user_id)
assert retrieved_config.threshold == config.threshold
assert retrieved_config.amount == config.amount
# Verify balance is unchanged (no immediate top-up from set_auto_top_up)
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 50 # Should be unchanged
# Verify no immediate auto-top-up transaction was created by set_auto_top_up
transactions = await CreditTransaction.prisma().find_many(
where={"userId": user_id}, order={"createdAt": "desc"}
)
# Should only have the initial GRANT transaction
assert len(transactions) == 1
assert transactions[0].type == CreditTransactionType.GRANT

View File

@@ -1,141 +0,0 @@
"""
Tests for credit system metadata handling to ensure JSON casting works correctly.
This test verifies that metadata parameters are properly serialized when passed
to raw SQL queries with JSONB columns.
"""
# type: ignore
from typing import Any
import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, UserBalance
from backend.data.credit import BetaUserCredit
from backend.data.user import DEFAULT_USER_ID
from backend.util.json import SafeJson
@pytest.fixture
async def setup_test_user():
"""Setup test user and cleanup after test."""
user_id = DEFAULT_USER_ID
# Cleanup before test
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await UserBalance.prisma().delete_many(where={"userId": user_id})
yield user_id
# Cleanup after test
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await UserBalance.prisma().delete_many(where={"userId": user_id})
@pytest.mark.asyncio(loop_scope="session")
async def test_metadata_json_serialization(setup_test_user):
"""Test that metadata is properly serialized for JSONB column in raw SQL."""
user_id = setup_test_user
credit_system = BetaUserCredit(1000)
# Test with complex metadata that would fail if not properly serialized
complex_metadata = SafeJson(
{
"graph_exec_id": "test-12345",
"reason": "Testing metadata serialization",
"nested_data": {
"key1": "value1",
"key2": ["array", "of", "values"],
"key3": {"deeply": {"nested": "object"}},
},
"special_chars": "Testing 'quotes' and \"double quotes\" and unicode: 🚀",
}
)
# This should work without throwing a JSONB casting error
balance, tx_key = await credit_system._add_transaction(
user_id=user_id,
amount=500, # $5 top-up
transaction_type=CreditTransactionType.TOP_UP,
metadata=complex_metadata,
is_active=True,
)
# Verify the transaction was created successfully
assert balance == 500
# Verify the metadata was stored correctly in the database
transaction = await CreditTransaction.prisma().find_first(
where={"userId": user_id, "transactionKey": tx_key}
)
assert transaction is not None
assert transaction.metadata is not None
# Verify the metadata contains our complex data
metadata_dict: dict[str, Any] = dict(transaction.metadata) # type: ignore
assert metadata_dict["graph_exec_id"] == "test-12345"
assert metadata_dict["reason"] == "Testing metadata serialization"
assert metadata_dict["nested_data"]["key1"] == "value1"
assert metadata_dict["nested_data"]["key3"]["deeply"]["nested"] == "object"
assert (
metadata_dict["special_chars"]
== "Testing 'quotes' and \"double quotes\" and unicode: 🚀"
)
@pytest.mark.asyncio(loop_scope="session")
async def test_enable_transaction_metadata_serialization(setup_test_user):
"""Test that _enable_transaction also handles metadata JSON serialization correctly."""
user_id = setup_test_user
credit_system = BetaUserCredit(1000)
# First create an inactive transaction
balance, tx_key = await credit_system._add_transaction(
user_id=user_id,
amount=300,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"initial": "inactive_transaction"}),
is_active=False, # Create as inactive
)
# Initial balance should be 0 because transaction is inactive
assert balance == 0
# Now enable the transaction with new metadata
enable_metadata = SafeJson(
{
"payment_method": "stripe",
"payment_intent": "pi_test_12345",
"activation_reason": "Payment confirmed",
"complex_data": {"array": [1, 2, 3], "boolean": True, "null_value": None},
}
)
# This should work without JSONB casting errors
final_balance = await credit_system._enable_transaction(
transaction_key=tx_key,
user_id=user_id,
metadata=enable_metadata,
)
# Now balance should reflect the activated transaction
assert final_balance == 300
# Verify the metadata was updated correctly
transaction = await CreditTransaction.prisma().find_first(
where={"userId": user_id, "transactionKey": tx_key}
)
assert transaction is not None
assert transaction.isActive is True
# Verify the metadata was updated with enable_metadata
metadata_dict: dict[str, Any] = dict(transaction.metadata) # type: ignore
assert metadata_dict["payment_method"] == "stripe"
assert metadata_dict["payment_intent"] == "pi_test_12345"
assert metadata_dict["complex_data"]["array"] == [1, 2, 3]
assert metadata_dict["complex_data"]["boolean"] is True
assert metadata_dict["complex_data"]["null_value"] is None

View File

@@ -1,372 +0,0 @@
"""
Tests for credit system refund and dispute operations.
These tests ensure that refund operations (deduct_credits, handle_dispute)
are atomic and maintain data consistency.
"""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
import stripe
from prisma.enums import CreditTransactionType
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
from backend.data.credit import UserCredit
from backend.util.json import SafeJson
from backend.util.test import SpinTestServer
credit_system = UserCredit()
# Test user ID for refund tests
REFUND_TEST_USER_ID = "refund-test-user"
async def setup_test_user_with_topup():
"""Create a test user with initial balance and a top-up transaction."""
# Clean up any existing data
await CreditRefundRequest.prisma().delete_many(
where={"userId": REFUND_TEST_USER_ID}
)
await CreditTransaction.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
await UserBalance.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
await User.prisma().delete_many(where={"id": REFUND_TEST_USER_ID})
# Create user
await User.prisma().create(
data={
"id": REFUND_TEST_USER_ID,
"email": f"{REFUND_TEST_USER_ID}@example.com",
"name": "Refund Test User",
}
)
# Create user balance
await UserBalance.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"balance": 1000, # $10
}
)
# Create a top-up transaction that can be refunded
topup_tx = await CreditTransaction.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"amount": 1000,
"type": CreditTransactionType.TOP_UP,
"transactionKey": "pi_test_12345",
"runningBalance": 1000,
"isActive": True,
"metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
}
)
return topup_tx
async def cleanup_test_user():
"""Clean up test data."""
await CreditRefundRequest.prisma().delete_many(
where={"userId": REFUND_TEST_USER_ID}
)
await CreditTransaction.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
await UserBalance.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
await User.prisma().delete_many(where={"id": REFUND_TEST_USER_ID})
@pytest.mark.asyncio(loop_scope="session")
async def test_deduct_credits_atomic(server: SpinTestServer):
"""Test that deduct_credits is atomic and creates transaction correctly."""
topup_tx = await setup_test_user_with_topup()
try:
# Create a mock refund object
refund = MagicMock(spec=stripe.Refund)
refund.id = "re_test_refund_123"
refund.payment_intent = topup_tx.transactionKey
refund.amount = 500 # Refund $5 of the $10 top-up
refund.status = "succeeded"
refund.reason = "requested_by_customer"
refund.created = int(datetime.now(timezone.utc).timestamp())
# Create refund request record (simulating webhook flow)
await CreditRefundRequest.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"amount": 500,
"transactionKey": topup_tx.transactionKey, # Should match the original transaction
"reason": "Test refund",
}
)
# Call deduct_credits
await credit_system.deduct_credits(refund)
# Verify the user's balance was deducted
user_balance = await UserBalance.prisma().find_unique(
where={"userId": REFUND_TEST_USER_ID}
)
assert user_balance is not None
assert (
user_balance.balance == 500
), f"Expected balance 500, got {user_balance.balance}"
# Verify refund transaction was created
refund_tx = await CreditTransaction.prisma().find_first(
where={
"userId": REFUND_TEST_USER_ID,
"type": CreditTransactionType.REFUND,
"transactionKey": refund.id,
}
)
assert refund_tx is not None
assert refund_tx.amount == -500
assert refund_tx.runningBalance == 500
assert refund_tx.isActive
# Verify refund request was updated
refund_request = await CreditRefundRequest.prisma().find_first(
where={
"userId": REFUND_TEST_USER_ID,
"transactionKey": topup_tx.transactionKey,
}
)
assert refund_request is not None
assert (
refund_request.result
== "The refund request has been approved, the amount will be credited back to your account."
)
finally:
await cleanup_test_user()
@pytest.mark.asyncio(loop_scope="session")
async def test_deduct_credits_user_not_found(server: SpinTestServer):
"""Test that deduct_credits raises error if transaction not found (which means user doesn't exist)."""
# Create a mock refund object that references a non-existent payment intent
refund = MagicMock(spec=stripe.Refund)
refund.id = "re_test_refund_nonexistent"
refund.payment_intent = "pi_test_nonexistent" # This payment intent doesn't exist
refund.amount = 500
refund.status = "succeeded"
refund.reason = "requested_by_customer"
refund.created = int(datetime.now(timezone.utc).timestamp())
# Should raise error for missing transaction
with pytest.raises(Exception): # Should raise NotFoundError for missing transaction
await credit_system.deduct_credits(refund)
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.data.credit.settings")
@patch("stripe.Dispute.modify")
@patch("backend.data.credit.get_user_by_id")
async def test_handle_dispute_with_sufficient_balance(
mock_get_user, mock_stripe_modify, mock_settings, server: SpinTestServer
):
"""Test handling dispute when user has sufficient balance (dispute gets closed)."""
topup_tx = await setup_test_user_with_topup()
try:
# Mock settings to have a low tolerance threshold
mock_settings.config.refund_credit_tolerance_threshold = 0
# Mock the user lookup
mock_user = MagicMock()
mock_user.email = f"{REFUND_TEST_USER_ID}@example.com"
mock_get_user.return_value = mock_user
# Create a mock dispute object for small amount (user has 1000, disputing 100)
dispute = MagicMock(spec=stripe.Dispute)
dispute.id = "dp_test_dispute_123"
dispute.payment_intent = topup_tx.transactionKey
dispute.amount = 100 # Small dispute amount
dispute.status = "pending"
dispute.reason = "fraudulent"
dispute.created = int(datetime.now(timezone.utc).timestamp())
# Mock the close method to prevent real API calls
dispute.close = MagicMock()
# Handle the dispute
await credit_system.handle_dispute(dispute)
# Verify dispute.close() was called (since user has sufficient balance)
dispute.close.assert_called_once()
# Verify no stripe evidence was added since dispute was closed
mock_stripe_modify.assert_not_called()
# Verify the user's balance was NOT deducted (dispute was closed)
user_balance = await UserBalance.prisma().find_unique(
where={"userId": REFUND_TEST_USER_ID}
)
assert user_balance is not None
assert (
user_balance.balance == 1000
), f"Balance should remain 1000, got {user_balance.balance}"
finally:
await cleanup_test_user()
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.data.credit.settings")
@patch("stripe.Dispute.modify")
@patch("backend.data.credit.get_user_by_id")
async def test_handle_dispute_with_insufficient_balance(
mock_get_user, mock_stripe_modify, mock_settings, server: SpinTestServer
):
"""Test handling dispute when user has insufficient balance (evidence gets added)."""
topup_tx = await setup_test_user_with_topup()
# Save original method for restoration before any try blocks
original_get_history = credit_system.get_transaction_history
try:
# Mock settings to have a high tolerance threshold so dispute isn't closed
mock_settings.config.refund_credit_tolerance_threshold = 2000
# Mock the user lookup
mock_user = MagicMock()
mock_user.email = f"{REFUND_TEST_USER_ID}@example.com"
mock_get_user.return_value = mock_user
# Mock the transaction history method to return an async result
from unittest.mock import AsyncMock
mock_history = MagicMock()
mock_history.transactions = []
credit_system.get_transaction_history = AsyncMock(return_value=mock_history)
# Create a mock dispute object for full amount (user has 1000, disputing 1000)
dispute = MagicMock(spec=stripe.Dispute)
dispute.id = "dp_test_dispute_pending"
dispute.payment_intent = topup_tx.transactionKey
dispute.amount = 1000
dispute.status = "warning_needs_response"
dispute.created = int(datetime.now(timezone.utc).timestamp())
# Mock the close method to prevent real API calls
dispute.close = MagicMock()
# Handle the dispute (evidence should be added)
await credit_system.handle_dispute(dispute)
# Verify dispute.close() was NOT called (insufficient balance after tolerance)
dispute.close.assert_not_called()
# Verify stripe evidence was added since dispute wasn't closed
mock_stripe_modify.assert_called_once()
# Verify the user's balance was NOT deducted (handle_dispute doesn't deduct credits)
user_balance = await UserBalance.prisma().find_unique(
where={"userId": REFUND_TEST_USER_ID}
)
assert user_balance is not None
assert user_balance.balance == 1000, "Balance should remain unchanged"
finally:
credit_system.get_transaction_history = original_get_history
await cleanup_test_user()
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_refunds(server: SpinTestServer):
"""Test that concurrent refunds are handled atomically."""
import asyncio
topup_tx = await setup_test_user_with_topup()
try:
# Create multiple refund requests
refund_requests = []
for i in range(5):
req = await CreditRefundRequest.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"amount": 100, # $1 each
"transactionKey": topup_tx.transactionKey,
"reason": f"Test refund {i}",
}
)
refund_requests.append(req)
# Create refund tasks to run concurrently
async def process_refund(index: int):
refund = MagicMock(spec=stripe.Refund)
refund.id = f"re_test_concurrent_{index}"
refund.payment_intent = topup_tx.transactionKey
refund.amount = 100 # $1 refund
refund.status = "succeeded"
refund.reason = "requested_by_customer"
refund.created = int(datetime.now(timezone.utc).timestamp())
try:
await credit_system.deduct_credits(refund)
return "success"
except Exception as e:
return f"error: {e}"
# Run refunds concurrently
results = await asyncio.gather(
*[process_refund(i) for i in range(5)], return_exceptions=True
)
# All should succeed
assert all(r == "success" for r in results), f"Some refunds failed: {results}"
# Verify final balance - with non-atomic implementation, this will demonstrate race condition
# EXPECTED BEHAVIOR: Due to race conditions, not all refunds will be properly processed
# The balance will be incorrect (higher than expected) showing lost updates
user_balance = await UserBalance.prisma().find_unique(
where={"userId": REFUND_TEST_USER_ID}
)
assert user_balance is not None
# With atomic implementation, this should be 500 (1000 - 5*100)
# With current non-atomic implementation, this will likely be wrong due to race conditions
print(f"DEBUG: Final balance = {user_balance.balance}, expected = 500")
# With atomic implementation, all 5 refunds should process correctly
assert (
user_balance.balance == 500
), f"Expected balance 500 after 5 refunds of 100 each, got {user_balance.balance}"
# Verify all refund transactions exist
refund_txs = await CreditTransaction.prisma().find_many(
where={
"userId": REFUND_TEST_USER_ID,
"type": CreditTransactionType.REFUND,
}
)
assert (
len(refund_txs) == 5
), f"Expected 5 refund transactions, got {len(refund_txs)}"
running_balances: set[int] = {
tx.runningBalance for tx in refund_txs if tx.runningBalance is not None
}
# Verify all balances are valid intermediate states
for balance in running_balances:
assert (
500 <= balance <= 1000
), f"Invalid balance {balance}, should be between 500 and 1000"
# Final balance should be present
assert (
500 in running_balances
), f"Final balance 500 should be in {running_balances}"
# All balances should be unique and form a valid sequence
sorted_balances = sorted(running_balances, reverse=True)
assert (
len(sorted_balances) == 5
), f"Expected 5 unique balances, got {len(sorted_balances)}"
finally:
await cleanup_test_user()

View File

@@ -1,8 +1,8 @@
from datetime import datetime, timedelta, timezone
from datetime import datetime, timezone
import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, UserBalance
from prisma.models import CreditTransaction
from backend.blocks.llm import AITextGeneratorBlock
from backend.data.block import get_block
@@ -19,24 +19,14 @@ user_credit = BetaUserCredit(REFILL_VALUE)
async def disable_test_user_transactions():
await CreditTransaction.prisma().delete_many(where={"userId": DEFAULT_USER_ID})
# Also reset the balance to 0 and set updatedAt to old date to trigger monthly refill
old_date = datetime.now(timezone.utc) - timedelta(days=35) # More than a month ago
await UserBalance.prisma().upsert(
where={"userId": DEFAULT_USER_ID},
data={
"create": {"userId": DEFAULT_USER_ID, "balance": 0},
"update": {"balance": 0, "updatedAt": old_date},
},
)
async def top_up(amount: int):
balance, _ = await user_credit._add_transaction(
await user_credit._add_transaction(
DEFAULT_USER_ID,
amount,
CreditTransactionType.TOP_UP,
)
return balance
async def spend_credits(entry: NodeExecutionEntry) -> int:
@@ -121,90 +111,29 @@ async def test_block_credit_top_up(server: SpinTestServer):
@pytest.mark.asyncio(loop_scope="session")
async def test_block_credit_reset(server: SpinTestServer):
"""Test that BetaUserCredit provides monthly refills correctly."""
await disable_test_user_transactions()
month1 = 1
month2 = 2
# Save original time_now function for restoration
original_time_now = user_credit.time_now
# set the calendar to month 2 but use current time from now
user_credit.time_now = lambda: datetime.now(timezone.utc).replace(
month=month2, day=1
)
month2credit = await user_credit.get_credits(DEFAULT_USER_ID)
try:
# Test month 1 behavior
month1 = datetime.now(timezone.utc).replace(month=1, day=1)
user_credit.time_now = lambda: month1
# Month 1 result should only affect month 1
user_credit.time_now = lambda: datetime.now(timezone.utc).replace(
month=month1, day=1
)
month1credit = await user_credit.get_credits(DEFAULT_USER_ID)
await top_up(100)
assert await user_credit.get_credits(DEFAULT_USER_ID) == month1credit + 100
# First call in month 1 should trigger refill
balance = await user_credit.get_credits(DEFAULT_USER_ID)
assert balance == REFILL_VALUE # Should get 1000 credits
# Manually create a transaction with month 1 timestamp to establish history
await CreditTransaction.prisma().create(
data={
"userId": DEFAULT_USER_ID,
"amount": 100,
"type": CreditTransactionType.TOP_UP,
"runningBalance": 1100,
"isActive": True,
"createdAt": month1, # Set specific timestamp
}
)
# Update user balance to match
await UserBalance.prisma().upsert(
where={"userId": DEFAULT_USER_ID},
data={
"create": {"userId": DEFAULT_USER_ID, "balance": 1100},
"update": {"balance": 1100},
},
)
# Now test month 2 behavior
month2 = datetime.now(timezone.utc).replace(month=2, day=1)
user_credit.time_now = lambda: month2
# In month 2, since balance (1100) > refill (1000), no refill should happen
month2_balance = await user_credit.get_credits(DEFAULT_USER_ID)
assert month2_balance == 1100 # Balance persists, no reset
# Now test the refill behavior when balance is low
# Set balance below refill threshold
await UserBalance.prisma().update(
where={"userId": DEFAULT_USER_ID}, data={"balance": 400}
)
# Create a month 2 transaction to update the last transaction time
await CreditTransaction.prisma().create(
data={
"userId": DEFAULT_USER_ID,
"amount": -700, # Spent 700 to get to 400
"type": CreditTransactionType.USAGE,
"runningBalance": 400,
"isActive": True,
"createdAt": month2,
}
)
# Move to month 3
month3 = datetime.now(timezone.utc).replace(month=3, day=1)
user_credit.time_now = lambda: month3
# Should get refilled since balance (400) < refill value (1000)
month3_balance = await user_credit.get_credits(DEFAULT_USER_ID)
assert month3_balance == REFILL_VALUE # Should be refilled to 1000
# Verify the refill transaction was created
refill_tx = await CreditTransaction.prisma().find_first(
where={
"userId": DEFAULT_USER_ID,
"type": CreditTransactionType.GRANT,
"transactionKey": {"contains": "MONTHLY-CREDIT-TOP-UP"},
},
order={"createdAt": "desc"},
)
assert refill_tx is not None, "Monthly refill transaction should be created"
assert refill_tx.amount == 600, "Refill should be 600 (1000 - 400)"
finally:
# Restore original time_now function
user_credit.time_now = original_time_now
# Month 2 balance is unaffected
user_credit.time_now = lambda: datetime.now(timezone.utc).replace(
month=month2, day=1
)
assert await user_credit.get_credits(DEFAULT_USER_ID) == month2credit
@pytest.mark.asyncio(loop_scope="session")

View File

@@ -1,361 +0,0 @@
"""
Test underflow protection for cumulative refunds and negative transactions.
This test ensures that when multiple large refunds are processed, the user balance
doesn't underflow below POSTGRES_INT_MIN, which could cause integer wraparound issues.
"""
import asyncio
from uuid import uuid4
import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from backend.data.credit import POSTGRES_INT_MIN, UserCredit
from backend.util.test import SpinTestServer
async def create_test_user(user_id: str) -> None:
"""Create a test user for underflow tests."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
)
except UniqueViolationError:
# User already exists, continue
pass
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
)
async def cleanup_test_user(user_id: str) -> None:
"""Clean up test user and their transactions."""
try:
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await UserBalance.prisma().delete_many(where={"userId": user_id})
await User.prisma().delete_many(where={"id": user_id})
except Exception as e:
# Log cleanup failures but don't fail the test
print(f"Warning: Failed to cleanup test user {user_id}: {e}")
@pytest.mark.asyncio(loop_scope="session")
async def test_debug_underflow_step_by_step(server: SpinTestServer):
"""Debug underflow behavior step by step."""
credit_system = UserCredit()
user_id = f"debug-underflow-{uuid4()}"
await create_test_user(user_id)
try:
print(f"POSTGRES_INT_MIN: {POSTGRES_INT_MIN}")
# Test 1: Set up balance close to underflow threshold
print("\n=== Test 1: Setting up balance close to underflow threshold ===")
# First, manually set balance to a value very close to POSTGRES_INT_MIN
# We'll set it to POSTGRES_INT_MIN + 100, then try to subtract 200
# This should trigger underflow protection: (POSTGRES_INT_MIN + 100) + (-200) = POSTGRES_INT_MIN - 100
initial_balance_target = POSTGRES_INT_MIN + 100
# Use direct database update to set the balance close to underflow
from prisma.models import UserBalance
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": initial_balance_target},
"update": {"balance": initial_balance_target},
},
)
current_balance = await credit_system.get_credits(user_id)
print(f"Set balance to: {current_balance}")
assert current_balance == initial_balance_target
# Test 2: Apply amount that should cause underflow
print("\n=== Test 2: Testing underflow protection ===")
test_amount = (
-200
) # This should cause underflow: (POSTGRES_INT_MIN + 100) + (-200) = POSTGRES_INT_MIN - 100
expected_without_protection = current_balance + test_amount
print(f"Current balance: {current_balance}")
print(f"Test amount: {test_amount}")
print(f"Without protection would be: {expected_without_protection}")
print(f"Should be clamped to POSTGRES_INT_MIN: {POSTGRES_INT_MIN}")
# Apply the amount that should trigger underflow protection
balance_result, _ = await credit_system._add_transaction(
user_id=user_id,
amount=test_amount,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False,
)
print(f"Actual result: {balance_result}")
# Check if underflow protection worked
assert (
balance_result == POSTGRES_INT_MIN
), f"Expected underflow protection to clamp balance to {POSTGRES_INT_MIN}, got {balance_result}"
# Test 3: Edge case - exactly at POSTGRES_INT_MIN
print("\n=== Test 3: Testing exact POSTGRES_INT_MIN boundary ===")
# Set balance to exactly POSTGRES_INT_MIN
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
"update": {"balance": POSTGRES_INT_MIN},
},
)
edge_balance = await credit_system.get_credits(user_id)
print(f"Balance set to exactly POSTGRES_INT_MIN: {edge_balance}")
# Try to subtract 1 - should stay at POSTGRES_INT_MIN
edge_result, _ = await credit_system._add_transaction(
user_id=user_id,
amount=-1,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False,
)
print(f"After subtracting 1: {edge_result}")
assert (
edge_result == POSTGRES_INT_MIN
), f"Expected balance to remain clamped at {POSTGRES_INT_MIN}, got {edge_result}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_underflow_protection_large_refunds(server: SpinTestServer):
"""Test that large cumulative refunds don't cause integer underflow."""
credit_system = UserCredit()
user_id = f"underflow-test-{uuid4()}"
await create_test_user(user_id)
try:
# Set up balance close to underflow threshold to test the protection
# Set balance to POSTGRES_INT_MIN + 1000, then try to subtract 2000
# This should trigger underflow protection
from prisma.models import UserBalance
test_balance = POSTGRES_INT_MIN + 1000
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": test_balance},
"update": {"balance": test_balance},
},
)
current_balance = await credit_system.get_credits(user_id)
assert current_balance == test_balance
# Try to deduct amount that would cause underflow: test_balance + (-2000) = POSTGRES_INT_MIN - 1000
underflow_amount = -2000
expected_without_protection = (
current_balance + underflow_amount
) # Should be POSTGRES_INT_MIN - 1000
# Use _add_transaction directly with amount that would cause underflow
final_balance, _ = await credit_system._add_transaction(
user_id=user_id,
amount=underflow_amount,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False, # Allow going negative for refunds
)
# Balance should be clamped to POSTGRES_INT_MIN, not the calculated underflow value
assert (
final_balance == POSTGRES_INT_MIN
), f"Balance should be clamped to {POSTGRES_INT_MIN}, got {final_balance}"
assert (
final_balance > expected_without_protection
), f"Balance should be greater than underflow result {expected_without_protection}, got {final_balance}"
# Verify with get_credits too
stored_balance = await credit_system.get_credits(user_id)
assert (
stored_balance == POSTGRES_INT_MIN
), f"Stored balance should be {POSTGRES_INT_MIN}, got {stored_balance}"
# Verify transaction was created with the underflow-protected balance
transactions = await CreditTransaction.prisma().find_many(
where={"userId": user_id, "type": CreditTransactionType.REFUND},
order={"createdAt": "desc"},
)
assert len(transactions) > 0, "Refund transaction should be created"
assert (
transactions[0].runningBalance == POSTGRES_INT_MIN
), f"Transaction should show clamped balance {POSTGRES_INT_MIN}, got {transactions[0].runningBalance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServer):
"""Test that multiple large refunds applied sequentially don't cause underflow."""
credit_system = UserCredit()
user_id = f"cumulative-underflow-test-{uuid4()}"
await create_test_user(user_id)
try:
# Set up balance close to underflow threshold
from prisma.models import UserBalance
initial_balance = POSTGRES_INT_MIN + 500 # Close to minimum but with some room
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": initial_balance},
"update": {"balance": initial_balance},
},
)
# Apply multiple refunds that would cumulatively underflow
refund_amount = -300 # Each refund that would cause underflow when cumulative
# First refund: (POSTGRES_INT_MIN + 500) + (-300) = POSTGRES_INT_MIN + 200 (still above minimum)
balance_1, _ = await credit_system._add_transaction(
user_id=user_id,
amount=refund_amount,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False,
)
# Should be above minimum for first refund
expected_balance_1 = (
initial_balance + refund_amount
) # Should be POSTGRES_INT_MIN + 200
assert (
balance_1 == expected_balance_1
), f"First refund should result in {expected_balance_1}, got {balance_1}"
assert (
balance_1 >= POSTGRES_INT_MIN
), f"First refund should not go below {POSTGRES_INT_MIN}, got {balance_1}"
# Second refund: (POSTGRES_INT_MIN + 200) + (-300) = POSTGRES_INT_MIN - 100 (would underflow)
balance_2, _ = await credit_system._add_transaction(
user_id=user_id,
amount=refund_amount,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False,
)
# Should be clamped to minimum due to underflow protection
assert (
balance_2 == POSTGRES_INT_MIN
), f"Second refund should be clamped to {POSTGRES_INT_MIN}, got {balance_2}"
# Third refund: Should stay at minimum
balance_3, _ = await credit_system._add_transaction(
user_id=user_id,
amount=refund_amount,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False,
)
# Should still be at minimum
assert (
balance_3 == POSTGRES_INT_MIN
), f"Third refund should stay at {POSTGRES_INT_MIN}, got {balance_3}"
# Final balance check
final_balance = await credit_system.get_credits(user_id)
assert (
final_balance == POSTGRES_INT_MIN
), f"Final balance should be {POSTGRES_INT_MIN}, got {final_balance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):
"""Test that concurrent large refunds don't cause race condition underflow."""
credit_system = UserCredit()
user_id = f"concurrent-underflow-test-{uuid4()}"
await create_test_user(user_id)
try:
# Set up balance close to underflow threshold
from prisma.models import UserBalance
initial_balance = POSTGRES_INT_MIN + 1000 # Close to minimum
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": initial_balance},
"update": {"balance": initial_balance},
},
)
async def large_refund(amount: int, label: str):
try:
return await credit_system._add_transaction(
user_id=user_id,
amount=-amount,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False,
)
except Exception as e:
return f"FAILED-{label}: {e}"
# Run concurrent refunds that would cause underflow if not protected
# Each refund of 500 would cause underflow: initial_balance + (-500) could go below POSTGRES_INT_MIN
refund_amount = 500
results = await asyncio.gather(
large_refund(refund_amount, "A"),
large_refund(refund_amount, "B"),
large_refund(refund_amount, "C"),
return_exceptions=True,
)
# Check all results are valid and no underflow occurred
valid_results = []
for i, result in enumerate(results):
if isinstance(result, tuple):
balance, _ = result
assert (
balance >= POSTGRES_INT_MIN
), f"Result {i} balance {balance} underflowed below {POSTGRES_INT_MIN}"
valid_results.append(balance)
elif isinstance(result, str) and "FAILED" in result:
# Some operations might fail due to validation, that's okay
pass
else:
# Unexpected exception
assert not isinstance(
result, Exception
), f"Unexpected exception in result {i}: {result}"
# At least one operation should succeed
assert (
len(valid_results) > 0
), f"At least one refund should succeed, got results: {results}"
# All successful results should be >= POSTGRES_INT_MIN
for balance in valid_results:
assert (
balance >= POSTGRES_INT_MIN
), f"Balance {balance} should not be below {POSTGRES_INT_MIN}"
# Final balance should be valid and at or above POSTGRES_INT_MIN
final_balance = await credit_system.get_credits(user_id)
assert (
final_balance >= POSTGRES_INT_MIN
), f"Final balance {final_balance} should not underflow below {POSTGRES_INT_MIN}"
finally:
await cleanup_test_user(user_id)

View File

@@ -1,217 +0,0 @@
"""
Integration test to verify complete migration from User.balance to UserBalance table.
This test ensures that:
1. No User.balance queries exist in the system
2. All balance operations go through UserBalance table
3. User and UserBalance stay synchronized properly
"""
import asyncio
from datetime import datetime
import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from backend.data.credit import UsageTransactionMetadata, UserCredit
from backend.util.json import SafeJson
from backend.util.test import SpinTestServer
async def create_test_user(user_id: str) -> None:
"""Create a test user for migration tests."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
)
except UniqueViolationError:
# User already exists, continue
pass
async def cleanup_test_user(user_id: str) -> None:
"""Clean up test user and their data."""
try:
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await UserBalance.prisma().delete_many(where={"userId": user_id})
await User.prisma().delete_many(where={"id": user_id})
except Exception as e:
# Log cleanup failures but don't fail the test
print(f"Warning: Failed to cleanup test user {user_id}: {e}")
@pytest.mark.asyncio(loop_scope="session")
async def test_user_balance_migration_complete(server: SpinTestServer):
"""Test that User table balance is never used and UserBalance is source of truth."""
credit_system = UserCredit()
user_id = f"migration-test-{datetime.now().timestamp()}"
await create_test_user(user_id)
try:
# 1. Verify User table does NOT have balance set initially
user = await User.prisma().find_unique(where={"id": user_id})
assert user is not None
# User.balance should not exist or should be None/0 if it exists
user_balance_attr = getattr(user, "balance", None)
if user_balance_attr is not None:
assert (
user_balance_attr == 0 or user_balance_attr is None
), f"User.balance should be 0 or None, got {user_balance_attr}"
# 2. Perform various credit operations using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=1000,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "migration_test"}),
)
balance1 = await credit_system.get_credits(user_id)
assert balance1 == 1000
await credit_system.spend_credits(
user_id,
300,
UsageTransactionMetadata(
graph_exec_id="test", reason="Migration test spend"
),
)
balance2 = await credit_system.get_credits(user_id)
assert balance2 == 700
# 3. Verify UserBalance table has correct values
user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
assert user_balance is not None
assert (
user_balance.balance == 700
), f"UserBalance should be 700, got {user_balance.balance}"
# 4. CRITICAL: Verify User.balance is NEVER updated during operations
user_after = await User.prisma().find_unique(where={"id": user_id})
assert user_after is not None
user_balance_after = getattr(user_after, "balance", None)
if user_balance_after is not None:
# If User.balance exists, it should still be 0 (never updated)
assert (
user_balance_after == 0 or user_balance_after is None
), f"User.balance should remain 0/None after operations, got {user_balance_after}. This indicates User.balance is still being used!"
# 5. Verify get_credits always returns UserBalance value, not User.balance
final_balance = await credit_system.get_credits(user_id)
assert (
final_balance == user_balance.balance
), f"get_credits should return UserBalance value {user_balance.balance}, got {final_balance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_detect_stale_user_balance_queries(server: SpinTestServer):
"""Test to detect if any operations are still using User.balance instead of UserBalance."""
credit_system = UserCredit()
user_id = f"stale-query-test-{datetime.now().timestamp()}"
await create_test_user(user_id)
try:
# Create UserBalance with specific value
await UserBalance.prisma().create(
data={"userId": user_id, "balance": 5000} # $50
)
# Verify that get_credits returns UserBalance value (5000), not any stale User.balance value
balance = await credit_system.get_credits(user_id)
assert (
balance == 5000
), f"Expected get_credits to return 5000 from UserBalance, got {balance}"
# Verify all operations use UserBalance using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=1000,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "final_verification"}),
)
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 6000, f"Expected 6000, got {final_balance}"
# Verify UserBalance table has the correct value
user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
assert user_balance is not None
assert (
user_balance.balance == 6000
), f"UserBalance should be 6000, got {user_balance.balance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer):
"""Test that concurrent operations all use UserBalance locking, not User.balance."""
credit_system = UserCredit()
user_id = f"concurrent-userbalance-test-{datetime.now().timestamp()}"
await create_test_user(user_id)
try:
# Set initial balance in UserBalance
await UserBalance.prisma().create(data={"userId": user_id, "balance": 1000})
# Run concurrent operations to ensure they all use UserBalance atomic operations
async def concurrent_spend(amount: int, label: str):
try:
await credit_system.spend_credits(
user_id,
amount,
UsageTransactionMetadata(
graph_exec_id=f"concurrent-{label}",
reason=f"Concurrent test {label}",
),
)
return f"{label}-SUCCESS"
except Exception as e:
return f"{label}-FAILED: {e}"
# Run concurrent operations
results = await asyncio.gather(
concurrent_spend(100, "A"),
concurrent_spend(200, "B"),
concurrent_spend(300, "C"),
return_exceptions=True,
)
# All should succeed (1000 >= 100+200+300)
successful = [r for r in results if "SUCCESS" in str(r)]
assert len(successful) == 3, f"All operations should succeed, got {results}"
# Final balance should be 1000 - 600 = 400
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 400, f"Expected final balance 400, got {final_balance}"
# Verify UserBalance has correct value
user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
assert user_balance is not None
assert (
user_balance.balance == 400
), f"UserBalance should be 400, got {user_balance.balance}"
# Critical: If User.balance exists and was used, it might have wrong value
try:
user = await User.prisma().find_unique(where={"id": user_id})
user_balance_attr = getattr(user, "balance", None)
if user_balance_attr is not None:
# If User.balance exists, it should NOT be used for operations
# The fact that our final balance is correct from UserBalance proves the system is working
print(
f"✅ User.balance exists ({user_balance_attr}) but UserBalance ({user_balance.balance}) is being used correctly"
)
except Exception:
print("✅ User.balance column doesn't exist - migration is complete")
finally:
await cleanup_test_user(user_id)

View File

@@ -98,6 +98,42 @@ async def transaction(timeout: int = TRANSACTION_TIMEOUT):
yield tx
@asynccontextmanager
async def locked_transaction(key: str, timeout: int = TRANSACTION_TIMEOUT):
"""
Create a transaction and take a per-key advisory *transaction* lock.
- Uses a 64-bit lock id via hashtextextended(key, 0) to avoid 32-bit collisions.
- Bound by lock_timeout and statement_timeout so it won't block indefinitely.
- Lock is held for the duration of the transaction and auto-released on commit/rollback.
Args:
key: String lock key (e.g., "usr_trx_<uuid>").
timeout: Transaction/lock/statement timeout in milliseconds.
"""
async with transaction(timeout=timeout) as tx:
# Ensure we don't wait longer than desired
# Note: SET LOCAL doesn't support parameterized queries, must use string interpolation
await tx.execute_raw(f"SET LOCAL statement_timeout = '{int(timeout)}ms'") # type: ignore[arg-type]
await tx.execute_raw(f"SET LOCAL lock_timeout = '{int(timeout)}ms'") # type: ignore[arg-type]
# Block until acquired or lock_timeout hits
try:
await tx.execute_raw(
"SELECT pg_advisory_xact_lock(hashtextextended($1, 0))",
key,
)
except Exception as e:
# Normalize PG's lock timeout error to TimeoutError for callers
if "lock timeout" in str(e).lower():
raise TimeoutError(
f"Could not acquire lock for key={key!r} within {timeout}ms"
) from e
raise
yield tx
def get_database_schema() -> str:
"""Extract database schema from DATABASE_URL."""
parsed_url = urlparse(DATABASE_URL)

View File

@@ -1,284 +0,0 @@
"""
Utilities for handling dynamic field names with special delimiters.
Dynamic fields allow graphs to connect complex data structures using special delimiters:
- _#_ for dictionary keys (e.g., "values_#_name" → values["name"])
- _$_ for list indices (e.g., "items_$_0" → items[0])
- _@_ for object attributes (e.g., "obj_@_attr" → obj.attr)
"""
from typing import Any
from backend.util.mock import MockObject
# Dynamic field delimiters
LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"
DYNAMIC_DELIMITERS = (LIST_SPLIT, DICT_SPLIT, OBJC_SPLIT)
def extract_base_field_name(field_name: str) -> str:
"""
Extract the base field name from a dynamic field name by removing all dynamic suffixes.
Examples:
extract_base_field_name("values_#_name") → "values"
extract_base_field_name("items_$_0") → "items"
extract_base_field_name("obj_@_attr") → "obj"
extract_base_field_name("regular_field") → "regular_field"
Args:
field_name: The field name that may contain dynamic delimiters
Returns:
The base field name without any dynamic suffixes
"""
base_name = field_name
for delimiter in DYNAMIC_DELIMITERS:
if delimiter in base_name:
base_name = base_name.split(delimiter)[0]
return base_name
def is_dynamic_field(field_name: str) -> bool:
"""
Check if a field name contains dynamic delimiters.
Args:
field_name: The field name to check
Returns:
True if the field contains any dynamic delimiters, False otherwise
"""
return any(delimiter in field_name for delimiter in DYNAMIC_DELIMITERS)
def get_dynamic_field_description(field_name: str) -> str:
"""
Generate a description for a dynamic field based on its structure.
Args:
field_name: The full dynamic field name (e.g., "values_#_name")
Returns:
A descriptive string explaining what this dynamic field represents
"""
base_name = extract_base_field_name(field_name)
if DICT_SPLIT in field_name:
# Extract the key part after _#_
parts = field_name.split(DICT_SPLIT)
if len(parts) > 1:
key = parts[1].split("_")[0] if "_" in parts[1] else parts[1]
return f"Dictionary field '{key}' for base field '{base_name}' ({base_name}['{key}'])"
elif LIST_SPLIT in field_name:
# Extract the index part after _$_
parts = field_name.split(LIST_SPLIT)
if len(parts) > 1:
index = parts[1].split("_")[0] if "_" in parts[1] else parts[1]
return (
f"List item {index} for base field '{base_name}' ({base_name}[{index}])"
)
elif OBJC_SPLIT in field_name:
# Extract the attribute part after _@_
parts = field_name.split(OBJC_SPLIT)
if len(parts) > 1:
# Get the full attribute name (everything after _@_)
attr = parts[1]
return f"Object attribute '{attr}' for base field '{base_name}' ({base_name}.{attr})"
return f"Value for {field_name}"
# --------------------------------------------------------------------------- #
# Dynamic field parsing and merging utilities
# --------------------------------------------------------------------------- #
def _next_delim(s: str) -> tuple[str | None, int]:
"""
Return the *earliest* delimiter appearing in `s` and its index.
If none present → (None, -1).
"""
first: str | None = None
pos = len(s) # sentinel: larger than any real index
for d in DYNAMIC_DELIMITERS:
i = s.find(d)
if 0 <= i < pos:
first, pos = d, i
return first, (pos if first else -1)
def _tokenise(path: str) -> list[tuple[str, str]] | None:
"""
Convert the raw path string (starting with a delimiter) into
[ (delimiter, identifier), … ] or None if the syntax is malformed.
"""
tokens: list[tuple[str, str]] = []
while path:
# 1. Which delimiter starts this chunk?
delim = next((d for d in DYNAMIC_DELIMITERS if path.startswith(d)), None)
if delim is None:
return None # invalid syntax
# 2. Slice off the delimiter, then up to the next delimiter (or EOS)
path = path[len(delim) :]
nxt_delim, pos = _next_delim(path)
token, path = (
path[: pos if pos != -1 else len(path)],
path[pos if pos != -1 else len(path) :],
)
if token == "":
return None # empty identifier is invalid
tokens.append((delim, token))
return tokens
def parse_execution_output(output: tuple[str, Any], name: str) -> Any:
"""
Retrieve a nested value out of `output` using the flattened *name*.
On any failure (wrong name, wrong type, out-of-range, bad path)
returns **None**.
Args:
output: Tuple of (base_name, data) representing a block output entry
name: The flattened field name to extract from the output data
Returns:
The value at the specified path, or None if not found/invalid
"""
base_name, data = output
# Exact match → whole object
if name == base_name:
return data
# Must start with the expected name
if not name.startswith(base_name):
return None
path = name[len(base_name) :]
if not path:
return None # nothing left to parse
tokens = _tokenise(path)
if tokens is None:
return None
cur: Any = data
for delim, ident in tokens:
if delim == LIST_SPLIT:
# list[index]
try:
idx = int(ident)
except ValueError:
return None
if not isinstance(cur, list) or idx >= len(cur):
return None
cur = cur[idx]
elif delim == DICT_SPLIT:
if not isinstance(cur, dict) or ident not in cur:
return None
cur = cur[ident]
elif delim == OBJC_SPLIT:
if not hasattr(cur, ident):
return None
cur = getattr(cur, ident)
else:
return None # unreachable
return cur
def _assign(container: Any, tokens: list[tuple[str, str]], value: Any) -> Any:
"""
Recursive helper that *returns* the (possibly new) container with
`value` assigned along the remaining `tokens` path.
"""
if not tokens:
return value # leaf reached
delim, ident = tokens[0]
rest = tokens[1:]
# ---------- list ----------
if delim == LIST_SPLIT:
try:
idx = int(ident)
except ValueError:
raise ValueError("index must be an integer")
if container is None:
container = []
elif not isinstance(container, list):
container = list(container) if hasattr(container, "__iter__") else []
while len(container) <= idx:
container.append(None)
container[idx] = _assign(container[idx], rest, value)
return container
# ---------- dict ----------
if delim == DICT_SPLIT:
if container is None:
container = {}
elif not isinstance(container, dict):
container = dict(container) if hasattr(container, "items") else {}
container[ident] = _assign(container.get(ident), rest, value)
return container
# ---------- object ----------
if delim == OBJC_SPLIT:
if container is None:
container = MockObject()
elif not hasattr(container, "__dict__"):
# If it's not an object, create a new one
container = MockObject()
setattr(
container,
ident,
_assign(getattr(container, ident, None), rest, value),
)
return container
return value # unreachable
def merge_execution_input(data: dict[str, Any]) -> dict[str, Any]:
"""
Reconstruct nested objects from a *flattened* dict of key → value.
Raises ValueError on syntactically invalid list indices.
Args:
data: Dictionary with potentially flattened dynamic field keys
Returns:
Dictionary with nested objects reconstructed from flattened keys
"""
merged: dict[str, Any] = {}
for key, value in data.items():
# Split off the base name (before the first delimiter, if any)
delim, pos = _next_delim(key)
if delim is None:
merged[key] = value
continue
base, path = key[:pos], key[pos:]
tokens = _tokenise(path)
if tokens is None:
# Invalid key; treat as scalar under the raw name
merged[key] = value
continue
merged[base] = _assign(merged.get(base), tokens, value)
data.update(merged)
return data

View File

@@ -38,8 +38,8 @@ from prisma.types import (
from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
from pydantic.fields import Field
from backend.server.v2.store.exceptions import DatabaseError
from backend.util import type as type_utils
from backend.util.exceptions import DatabaseError
from backend.util.json import SafeJson
from backend.util.models import Pagination
from backend.util.retry import func_retry
@@ -478,48 +478,6 @@ async def get_graph_executions(
return [GraphExecutionMeta.from_db(execution) for execution in executions]
async def get_graph_executions_count(
user_id: Optional[str] = None,
graph_id: Optional[str] = None,
statuses: Optional[list[ExecutionStatus]] = None,
created_time_gte: Optional[datetime] = None,
created_time_lte: Optional[datetime] = None,
) -> int:
"""
Get count of graph executions with optional filters.
Args:
user_id: Optional user ID to filter by
graph_id: Optional graph ID to filter by
statuses: Optional list of execution statuses to filter by
created_time_gte: Optional minimum creation time
created_time_lte: Optional maximum creation time
Returns:
Count of matching graph executions
"""
where_filter: AgentGraphExecutionWhereInput = {
"isDeleted": False,
}
if user_id:
where_filter["userId"] = user_id
if graph_id:
where_filter["agentGraphId"] = graph_id
if created_time_gte or created_time_lte:
where_filter["createdAt"] = {
"gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
"lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
}
if statuses:
where_filter["OR"] = [{"executionStatus": status} for status in statuses]
count = await AgentGraphExecution.prisma().count(where=where_filter)
return count
class GraphExecutionsPaginated(BaseModel):
"""Response schema for paginated graph executions."""

View File

@@ -7,7 +7,7 @@ from prisma.enums import AgentExecutionStatus
from backend.data.execution import get_graph_executions
from backend.data.graph import get_graph_metadata
from backend.data.model import UserExecutionSummaryStats
from backend.util.exceptions import DatabaseError
from backend.server.v2.store.exceptions import DatabaseError
from backend.util.logging import TruncatedLogger
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[SummaryData]")

View File

@@ -20,8 +20,6 @@ from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import LlmModel
from backend.data.db import prisma as db
from backend.data.dynamic_fields import extract_base_field_name
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
from backend.data.model import (
CredentialsField,
CredentialsFieldInfo,
@@ -33,15 +31,7 @@ from backend.util import type as type_utils
from backend.util.json import SafeJson
from backend.util.models import Pagination
from .block import (
Block,
BlockInput,
BlockSchema,
BlockType,
EmptySchema,
get_block,
get_blocks,
)
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
from .db import BaseDbModel, query_raw_with_schema, transaction
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
@@ -82,15 +72,12 @@ class Node(BaseDbModel):
output_links: list[Link] = []
@property
def block(self) -> "Block[BlockSchema, BlockSchema] | _UnknownBlockBase":
"""Get the block for this node. Returns UnknownBlock if block is deleted/missing."""
def block(self) -> Block[BlockSchema, BlockSchema]:
block = get_block(self.block_id)
if not block:
# Log warning but don't raise exception - return a placeholder block for deleted blocks
logger.warning(
f"Block #{self.block_id} does not exist for Node #{self.id} (deleted/missing block), using UnknownBlock"
raise ValueError(
f"Block #{self.block_id} does not exist -> Node #{self.id} is invalid"
)
return _UnknownBlockBase(self.block_id)
return block
@@ -129,20 +116,17 @@ class NodeModel(Node):
Returns a copy of the node model, stripped of any non-transferable properties
"""
stripped_node = self.model_copy(deep=True)
# Remove credentials and other (possible) secrets from node input
# Remove credentials from node input
if stripped_node.input_default:
stripped_node.input_default = NodeModel._filter_secrets_from_node_input(
stripped_node.input_default, self.block.input_schema.jsonschema()
)
# Remove default secret value from secret input nodes
if (
stripped_node.block.block_type == BlockType.INPUT
and stripped_node.input_default.get("secret", False) is True
and "value" in stripped_node.input_default
):
del stripped_node.input_default["value"]
stripped_node.input_default["value"] = ""
# Remove webhook info
stripped_node.webhook_id = None
@@ -159,10 +143,8 @@ class NodeModel(Node):
result = {}
for key, value in input_data.items():
field_schema: dict | None = field_schemas.get(key)
if (field_schema and field_schema.get("secret", False)) or (
any(sensitive_key in key.lower() for sensitive_key in sensitive_keys)
# Prevent removing `secret` flag on input nodes
and type(value) is not bool
if (field_schema and field_schema.get("secret", False)) or any(
sensitive_key in key.lower() for sensitive_key in sensitive_keys
):
# This is a secret value -> filter this key-value pair out
continue
@@ -747,7 +729,7 @@ def _is_tool_pin(name: str) -> bool:
def _sanitize_pin_name(name: str) -> str:
sanitized_name = extract_base_field_name(name)
sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
if _is_tool_pin(sanitized_name):
return "tools"
return sanitized_name
@@ -833,9 +815,9 @@ async def list_graphs_paginated(
where=where_clause,
distinct=["id"],
order={"version": "desc"},
include=AGENT_GRAPH_INCLUDE,
skip=offset,
take=page_size,
# Don't include nodes for list endpoint - GraphMeta excludes them anyway
)
graph_models: list[GraphMeta] = []
@@ -1077,14 +1059,11 @@ async def set_graph_active_version(graph_id: str, version: int, user_id: str) ->
)
async def get_graph_all_versions(
graph_id: str, user_id: str, limit: int = MAX_GRAPH_VERSIONS_FETCH
) -> list[GraphModel]:
async def get_graph_all_versions(graph_id: str, user_id: str) -> list[GraphModel]:
graph_versions = await AgentGraph.prisma().find_many(
where={"id": graph_id, "userId": user_id},
order={"version": "desc"},
include=AGENT_GRAPH_INCLUDE,
take=limit,
)
if not graph_versions:
@@ -1333,34 +1312,3 @@ async def migrate_llm_models(migrate_to: LlmModel):
id,
path,
)
# Simple placeholder class for deleted/missing blocks
class _UnknownBlockBase(Block):
"""
Placeholder for deleted/missing blocks that inherits from Block
but uses a name that doesn't end with 'Block' to avoid auto-discovery.
"""
def __init__(self, block_id: str = "00000000-0000-0000-0000-000000000000"):
# Initialize with minimal valid Block parameters
super().__init__(
id=block_id,
description=f"Unknown or deleted block (original ID: {block_id})",
disabled=True,
input_schema=EmptySchema,
output_schema=EmptySchema,
categories=set(),
contributors=[],
static_output=False,
block_type=BlockType.STANDARD,
webhook_config=None,
)
@property
def name(self):
return "UnknownBlock"
async def run(self, input_data, **kwargs):
"""Always yield an error for missing blocks."""
yield "error", f"Block {self.id} no longer exists"

View File

@@ -201,56 +201,25 @@ async def test_get_input_schema(server: SpinTestServer, snapshot: Snapshot):
@pytest.mark.asyncio(loop_scope="session")
async def test_clean_graph(server: SpinTestServer):
"""
Test the stripped_for_export function that:
1. Removes sensitive/secret fields from node inputs
2. Removes webhook information
3. Preserves non-sensitive data including input block values
Test the clean_graph function that:
1. Clears input block values
2. Removes credentials from nodes
"""
# Create a graph with input blocks containing both sensitive and normal data
# Create a graph with input blocks and credentials
graph = Graph(
id="test_clean_graph",
name="Test Clean Graph",
description="Test graph cleaning",
nodes=[
Node(
id="input_node",
block_id=AgentInputBlock().id,
input_default={
"_test_id": "input_node",
"name": "test_input",
"value": "test value", # This should be preserved
"value": "test value",
"description": "Test input description",
},
),
Node(
block_id=AgentInputBlock().id,
input_default={
"_test_id": "input_node_secret",
"name": "secret_input",
"value": "another value",
"secret": True, # This makes the input secret
},
),
Node(
block_id=StoreValueBlock().id,
input_default={
"_test_id": "node_with_secrets",
"input": "normal_value",
"control_test_input": "should be preserved",
"api_key": "secret_api_key_123", # Should be filtered
"password": "secret_password_456", # Should be filtered
"token": "secret_token_789", # Should be filtered
"credentials": { # Should be filtered
"id": "fake-github-credentials-id",
"provider": "github",
"type": "api_key",
},
"anthropic_credentials": { # Should be filtered
"id": "fake-anthropic-credentials-id",
"provider": "anthropic",
"type": "api_key",
},
},
),
],
links=[],
)
@@ -262,54 +231,15 @@ async def test_clean_graph(server: SpinTestServer):
)
# Clean the graph
cleaned_graph = await server.agent_server.test_get_graph(
created_graph = await server.agent_server.test_get_graph(
created_graph.id, created_graph.version, DEFAULT_USER_ID, for_export=True
)
# Verify sensitive fields are removed but normal fields are preserved
# # Verify input block value is cleared
input_node = next(
n for n in cleaned_graph.nodes if n.input_default["_test_id"] == "input_node"
n for n in created_graph.nodes if n.block_id == AgentInputBlock().id
)
# Non-sensitive fields should be preserved
assert input_node.input_default["name"] == "test_input"
assert input_node.input_default["value"] == "test value" # Should be preserved now
assert input_node.input_default["description"] == "Test input description"
# Sensitive fields should be filtered out
assert "api_key" not in input_node.input_default
assert "password" not in input_node.input_default
# Verify secret input node preserves non-sensitive fields but removes secret value
secret_node = next(
n
for n in cleaned_graph.nodes
if n.input_default["_test_id"] == "input_node_secret"
)
assert secret_node.input_default["name"] == "secret_input"
assert "value" not in secret_node.input_default # Secret default should be removed
assert secret_node.input_default["secret"] is True
# Verify sensitive fields are filtered from nodes with secrets
secrets_node = next(
n
for n in cleaned_graph.nodes
if n.input_default["_test_id"] == "node_with_secrets"
)
# Normal fields should be preserved
assert secrets_node.input_default["input"] == "normal_value"
assert secrets_node.input_default["control_test_input"] == "should be preserved"
# Sensitive fields should be filtered out
assert "api_key" not in secrets_node.input_default
assert "password" not in secrets_node.input_default
assert "token" not in secrets_node.input_default
assert "credentials" not in secrets_node.input_default
assert "anthropic_credentials" not in secrets_node.input_default
# Verify webhook info is removed (if any nodes had it)
for node in cleaned_graph.nodes:
assert node.webhook_id is None
assert node.webhook is None
assert input_node.input_default["value"] == ""
@pytest.mark.asyncio(loop_scope="session")

View File

@@ -14,7 +14,6 @@ AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
"Nodes": {"include": AGENT_NODE_INCLUDE}
}
EXECUTION_RESULT_ORDER: list[prisma.types.AgentNodeExecutionOrderByInput] = [
{"queuedTime": "desc"},
# Fallback: Incomplete execs has no queuedTime.
@@ -29,13 +28,6 @@ EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
}
MAX_NODE_EXECUTIONS_FETCH = 1000
MAX_LIBRARY_AGENT_EXECUTIONS_FETCH = 10
# Default limits for potentially large result sets
MAX_CREDIT_REFUND_REQUESTS_FETCH = 100
MAX_INTEGRATION_WEBHOOKS_FETCH = 100
MAX_USER_API_KEYS_FETCH = 500
MAX_GRAPH_VERSIONS_FETCH = 50
GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
"NodeExecutions": {
@@ -79,56 +71,13 @@ INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
}
def library_agent_include(
user_id: str,
include_nodes: bool = True,
include_executions: bool = True,
execution_limit: int = MAX_LIBRARY_AGENT_EXECUTIONS_FETCH,
) -> prisma.types.LibraryAgentInclude:
"""
Fully configurable includes for library agent queries with performance optimization.
Args:
user_id: User ID for filtering user-specific data
include_nodes: Whether to include graph nodes (default: True, needed for get_sub_graphs)
include_executions: Whether to include executions (default: True, safe with execution_limit)
execution_limit: Limit on executions to fetch (default: MAX_LIBRARY_AGENT_EXECUTIONS_FETCH)
Defaults maintain backward compatibility and safety - includes everything needed for all functionality.
For performance optimization, explicitly set include_nodes=False and include_executions=False
for listing views where frontend fetches data separately.
Performance impact:
- Default (full nodes + limited executions): Original performance, works everywhere
- Listing optimization (no nodes/executions): ~2s for 15 agents vs potential timeouts
- Unlimited executions: varies by user (thousands of executions = timeouts)
"""
result: prisma.types.LibraryAgentInclude = {
"Creator": True, # Always needed for creator info
}
# Build AgentGraph include based on requested options
if include_nodes or include_executions:
agent_graph_include = {}
# Add nodes if requested (always full nodes)
if include_nodes:
agent_graph_include.update(AGENT_GRAPH_INCLUDE) # Full nodes
# Add executions if requested
if include_executions:
agent_graph_include["Executions"] = {
"where": {"userId": user_id},
"order_by": {"createdAt": "desc"},
"take": execution_limit,
def library_agent_include(user_id: str) -> prisma.types.LibraryAgentInclude:
return {
"AgentGraph": {
"include": {
**AGENT_GRAPH_INCLUDE,
"Executions": {"where": {"userId": user_id}},
}
result["AgentGraph"] = cast(
prisma.types.AgentGraphArgsFromLibraryAgent,
{"include": agent_graph_include},
)
else:
# Default: Basic metadata only (fast - recommended for most use cases)
result["AgentGraph"] = True # Basic graph metadata (name, description, id)
return result
},
"Creator": True,
}

View File

@@ -11,10 +11,7 @@ from prisma.types import (
from pydantic import Field, computed_field
from backend.data.event_bus import AsyncRedisEventBus
from backend.data.includes import (
INTEGRATION_WEBHOOK_INCLUDE,
MAX_INTEGRATION_WEBHOOKS_FETCH,
)
from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.utils import webhook_ingress_url
from backend.server.v2.library.model import LibraryAgentPreset
@@ -131,36 +128,22 @@ async def get_webhook(
@overload
async def get_all_webhooks_by_creds(
user_id: str,
credentials_id: str,
*,
include_relations: Literal[True],
limit: int = MAX_INTEGRATION_WEBHOOKS_FETCH,
user_id: str, credentials_id: str, *, include_relations: Literal[True]
) -> list[WebhookWithRelations]: ...
@overload
async def get_all_webhooks_by_creds(
user_id: str,
credentials_id: str,
*,
include_relations: Literal[False] = False,
limit: int = MAX_INTEGRATION_WEBHOOKS_FETCH,
user_id: str, credentials_id: str, *, include_relations: Literal[False] = False
) -> list[Webhook]: ...
async def get_all_webhooks_by_creds(
user_id: str,
credentials_id: str,
*,
include_relations: bool = False,
limit: int = MAX_INTEGRATION_WEBHOOKS_FETCH,
user_id: str, credentials_id: str, *, include_relations: bool = False
) -> list[Webhook] | list[WebhookWithRelations]:
if not credentials_id:
raise ValueError("credentials_id must not be empty")
webhooks = await IntegrationWebhook.prisma().find_many(
where={"userId": user_id, "credentialsId": credentials_id},
include=INTEGRATION_WEBHOOK_INCLUDE if include_relations else None,
order={"createdAt": "desc"},
take=limit,
)
return [
(WebhookWithRelations if include_relations else Webhook).from_db(webhook)

View File

@@ -270,7 +270,6 @@ def SchemaField(
min_length: Optional[int] = None,
max_length: Optional[int] = None,
discriminator: Optional[str] = None,
format: Optional[str] = None,
json_schema_extra: Optional[dict[str, Any]] = None,
) -> T:
if default is PydanticUndefined and default_factory is None:
@@ -286,7 +285,6 @@ def SchemaField(
"advanced": advanced,
"hidden": hidden,
"depends_on": depends_on,
"format": format,
**(json_schema_extra or {}),
}.items()
if v is not None
@@ -347,9 +345,6 @@ class APIKeyCredentials(_BaseCredentials):
"""Unix timestamp (seconds) indicating when the API key expires (if at all)"""
def auth_header(self) -> str:
# Linear API keys should not have Bearer prefix
if self.provider == "linear":
return self.api_key.get_secret_value()
return f"Bearer {self.api_key.get_secret_value()}"

View File

@@ -15,7 +15,7 @@ from prisma.types import (
# from backend.notifications.models import NotificationEvent
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
from backend.util.exceptions import DatabaseError
from backend.server.v2.store.exceptions import DatabaseError
from backend.util.json import SafeJson
from backend.util.logging import TruncatedLogger
@@ -235,7 +235,6 @@ class BaseEventModel(BaseModel):
class NotificationEventModel(BaseEventModel, Generic[NotificationDataType_co]):
id: Optional[str] = None # None when creating, populated when reading from DB
data: NotificationDataType_co
@property
@@ -379,7 +378,6 @@ class NotificationPreference(BaseModel):
class UserNotificationEventDTO(BaseModel):
id: str # Added to track notifications for removal
type: NotificationType
data: dict
created_at: datetime
@@ -388,7 +386,6 @@ class UserNotificationEventDTO(BaseModel):
@staticmethod
def from_db(model: NotificationEvent) -> "UserNotificationEventDTO":
return UserNotificationEventDTO(
id=model.id,
type=model.type,
data=dict(model.data),
created_at=model.createdAt,
@@ -544,79 +541,6 @@ async def empty_user_notification_batch(
) from e
async def clear_all_user_notification_batches(user_id: str) -> None:
"""Clear ALL notification batches for a user across all types.
Used when user's email is bounced/inactive and we should stop
trying to send them ANY emails.
"""
try:
# Delete all notification events for this user
await NotificationEvent.prisma().delete_many(
where={"UserNotificationBatch": {"is": {"userId": user_id}}}
)
# Delete all batches for this user
await UserNotificationBatch.prisma().delete_many(where={"userId": user_id})
logger.info(f"Cleared all notification batches for user {user_id}")
except Exception as e:
raise DatabaseError(
f"Failed to clear all notification batches for user {user_id}: {e}"
) from e
async def remove_notifications_from_batch(
user_id: str, notification_type: NotificationType, notification_ids: list[str]
) -> None:
"""Remove specific notifications from a user's batch by their IDs.
This is used after successful sending to remove only the
sent notifications, preventing duplicates on retry.
"""
if not notification_ids:
return
try:
# Delete the specific notification events
deleted_count = await NotificationEvent.prisma().delete_many(
where={
"id": {"in": notification_ids},
"UserNotificationBatch": {
"is": {"userId": user_id, "type": notification_type}
},
}
)
logger.info(
f"Removed {deleted_count} notifications from batch for user {user_id}"
)
# Check if batch is now empty and delete it if so
remaining = await NotificationEvent.prisma().count(
where={
"UserNotificationBatch": {
"is": {"userId": user_id, "type": notification_type}
}
}
)
if remaining == 0:
await UserNotificationBatch.prisma().delete_many(
where=UserNotificationBatchWhereInput(
userId=user_id,
type=notification_type,
)
)
logger.info(
f"Deleted empty batch for user {user_id} and type {notification_type}"
)
except Exception as e:
raise DatabaseError(
f"Failed to remove notifications from batch for user {user_id} and type {notification_type}: {e}"
) from e
async def get_user_notification_batch(
user_id: str,
notification_type: NotificationType,

View File

@@ -1,5 +1,4 @@
import re
from datetime import datetime
from typing import Any, Optional
import prisma
@@ -10,9 +9,9 @@ from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
from backend.data.block import get_blocks
from backend.data.credit import get_user_credit_model
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.server.v2.store.model import StoreAgentDetails
from backend.util.cache import cached
from backend.util.json import SafeJson
# Mapping from user reason id to categories to search for when choosing agent to show
@@ -26,10 +25,12 @@ REASON_MAPPING: dict[str, list[str]] = {
POINTS_AGENT_COUNT = 50 # Number of agents to calculate points for
MIN_AGENT_COUNT = 2 # Minimum number of marketplace agents to enable onboarding
user_credit = get_user_credit_model()
class UserOnboardingUpdate(pydantic.BaseModel):
completedSteps: Optional[list[OnboardingStep]] = None
walletShown: Optional[bool] = None
notificationDot: Optional[bool] = None
notified: Optional[list[OnboardingStep]] = None
usageReason: Optional[str] = None
integrations: Optional[list[str]] = None
@@ -38,8 +39,6 @@ class UserOnboardingUpdate(pydantic.BaseModel):
agentInput: Optional[dict[str, Any]] = None
onboardingAgentExecutionId: Optional[str] = None
agentRuns: Optional[int] = None
lastRunAt: Optional[datetime] = None
consecutiveRunDays: Optional[int] = None
async def get_user_onboarding(user_id: str):
@@ -58,22 +57,16 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
update["completedSteps"] = list(set(data.completedSteps))
for step in (
OnboardingStep.AGENT_NEW_RUN,
OnboardingStep.MARKETPLACE_VISIT,
OnboardingStep.RUN_AGENTS,
OnboardingStep.MARKETPLACE_ADD_AGENT,
OnboardingStep.MARKETPLACE_RUN_AGENT,
OnboardingStep.BUILDER_SAVE_AGENT,
OnboardingStep.RE_RUN_AGENT,
OnboardingStep.SCHEDULE_AGENT,
OnboardingStep.RUN_AGENTS,
OnboardingStep.RUN_3_DAYS,
OnboardingStep.TRIGGER_WEBHOOK,
OnboardingStep.RUN_14_DAYS,
OnboardingStep.RUN_AGENTS_100,
OnboardingStep.BUILDER_RUN_AGENT,
):
if step in data.completedSteps:
await reward_user(user_id, step)
if data.walletShown is not None:
update["walletShown"] = data.walletShown
if data.notificationDot is not None:
update["notificationDot"] = data.notificationDot
if data.notified is not None:
update["notified"] = list(set(data.notified))
if data.usageReason is not None:
@@ -90,10 +83,6 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId
if data.agentRuns is not None:
update["agentRuns"] = data.agentRuns
if data.lastRunAt is not None:
update["lastRunAt"] = data.lastRunAt
if data.consecutiveRunDays is not None:
update["consecutiveRunDays"] = data.consecutiveRunDays
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
@@ -112,28 +101,16 @@ async def reward_user(user_id: str, step: OnboardingStep):
# This is seen as a reward for the GET_RESULTS step in the wallet
case OnboardingStep.AGENT_NEW_RUN:
reward = 300
case OnboardingStep.MARKETPLACE_VISIT:
reward = 100
case OnboardingStep.RUN_AGENTS:
reward = 300
case OnboardingStep.MARKETPLACE_ADD_AGENT:
reward = 100
case OnboardingStep.MARKETPLACE_RUN_AGENT:
reward = 100
case OnboardingStep.BUILDER_SAVE_AGENT:
reward = 100
case OnboardingStep.RE_RUN_AGENT:
case OnboardingStep.BUILDER_RUN_AGENT:
reward = 100
case OnboardingStep.SCHEDULE_AGENT:
reward = 100
case OnboardingStep.RUN_AGENTS:
reward = 300
case OnboardingStep.RUN_3_DAYS:
reward = 100
case OnboardingStep.TRIGGER_WEBHOOK:
reward = 100
case OnboardingStep.RUN_14_DAYS:
reward = 300
case OnboardingStep.RUN_AGENTS_100:
reward = 300
if reward == 0:
return
@@ -145,8 +122,7 @@ async def reward_user(user_id: str, step: OnboardingStep):
return
onboarding.rewardedFor.append(step)
user_credit_model = await get_user_credit_model(user_id)
await user_credit_model.onboarding_reward(user_id, reward, step)
await user_credit.onboarding_reward(user_id, reward, step)
await UserOnboarding.prisma().update(
where={"userId": user_id},
data={
@@ -156,22 +132,6 @@ async def reward_user(user_id: str, step: OnboardingStep):
)
async def complete_webhook_trigger_step(user_id: str):
"""
Completes the TRIGGER_WEBHOOK onboarding step for the user if not already completed.
"""
onboarding = await get_user_onboarding(user_id)
if OnboardingStep.TRIGGER_WEBHOOK not in onboarding.completedSteps:
await update_user_onboarding(
user_id,
UserOnboardingUpdate(
completedSteps=onboarding.completedSteps
+ [OnboardingStep.TRIGGER_WEBHOOK]
),
)
def clean_and_split(text: str) -> list[str]:
"""
Removes all special characters from a string, truncates it to 100 characters,
@@ -276,14 +236,8 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
for word in user_onboarding.integrations
]
where_clause["is_available"] = True
# Try to take only agents that are available and allowed for onboarding
storeAgents = await prisma.models.StoreAgent.prisma().find_many(
where={
"is_available": True,
"useForOnboarding": True,
},
where=prisma.types.StoreAgentWhereInput(**where_clause),
order=[
{"featured": "desc"},
{"runs": "desc"},
@@ -292,16 +246,59 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
take=100,
)
# If not enough agents found, relax the useForOnboarding filter
agentListings = await prisma.models.StoreListingVersion.prisma().find_many(
where={
"id": {"in": [agent.storeListingVersionId for agent in storeAgents]},
},
include={"AgentGraph": True},
)
for listing in agentListings:
agent = listing.AgentGraph
if agent is None:
continue
graph = GraphModel.from_db(agent)
# Remove agents with empty input schema
if not graph.input_schema:
storeAgents = [
a for a in storeAgents if a.storeListingVersionId != listing.id
]
continue
# Remove agents with empty credentials
# Get nodes from this agent that have credentials
nodes = await prisma.models.AgentNode.prisma().find_many(
where={
"agentGraphId": agent.id,
"agentBlockId": {"in": list(CREDENTIALS_FIELDS.keys())},
},
)
for node in nodes:
block_id = node.agentBlockId
field_name = CREDENTIALS_FIELDS[block_id]
# If there are no credentials or they are empty, remove the agent
# FIXME ignores default values
if (
field_name not in node.constantInput
or node.constantInput[field_name] is None
):
storeAgents = [
a for a in storeAgents if a.storeListingVersionId != listing.id
]
break
# If there are less than 2 agents, add more agents to the list
if len(storeAgents) < 2:
storeAgents = await prisma.models.StoreAgent.prisma().find_many(
where=prisma.types.StoreAgentWhereInput(**where_clause),
storeAgents += await prisma.models.StoreAgent.prisma().find_many(
where={
"listing_id": {"not_in": [agent.listing_id for agent in storeAgents]},
},
order=[
{"featured": "desc"},
{"runs": "desc"},
{"rating": "desc"},
],
take=100,
take=2 - len(storeAgents),
)
# Calculate points for the first X agents and choose the top 2
@@ -336,13 +333,8 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
]
@cached(maxsize=1, ttl_seconds=300) # Cache for 5 minutes since this rarely changes
async def onboarding_enabled() -> bool:
"""
Check if onboarding should be enabled based on store agent count.
Cached to prevent repeated slow database queries.
"""
# Use a more efficient query that stops counting after finding enough agents
count = await prisma.models.StoreAgent.prisma().count(take=MIN_AGENT_COUNT + 1)
# Onboarding is enabled if there are at least 2 agents in the store
# Onboading is enabled if there are at least 2 agents in the store
return count >= MIN_AGENT_COUNT

View File

@@ -1,5 +0,0 @@
import prisma.models
class StoreAgentWithRank(prisma.models.StoreAgent):
rank: float

View File

@@ -1,11 +1,11 @@
import logging
import os
from autogpt_libs.utils.cache import cached, thread_cached
from dotenv import load_dotenv
from redis import Redis
from redis.asyncio import Redis as AsyncRedis
from backend.util.cache import cached, thread_cached
from backend.util.retry import conn_retry
load_dotenv()
@@ -34,7 +34,7 @@ def disconnect():
get_redis().close()
@cached(ttl_seconds=3600)
@cached()
def get_redis() -> Redis:
return connect()

View File

@@ -7,6 +7,7 @@ from typing import Optional, cast
from urllib.parse import quote_plus
from autogpt_libs.auth.models import DEFAULT_USER_ID
from autogpt_libs.utils.cache import cached
from fastapi import HTTPException
from prisma.enums import NotificationType
from prisma.models import User as PrismaUser
@@ -15,9 +16,8 @@ from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
from backend.data.db import prisma
from backend.data.model import User, UserIntegrations, UserMetadata
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.util.cache import cached
from backend.server.v2.store.exceptions import DatabaseError
from backend.util.encryption import JSONCryptor
from backend.util.exceptions import DatabaseError
from backend.util.json import SafeJson
from backend.util.settings import Settings
@@ -354,36 +354,6 @@ async def set_user_email_verification(user_id: str, verified: bool) -> None:
) from e
async def disable_all_user_notifications(user_id: str) -> None:
"""Disable all notification preferences for a user.
Used when user's email bounces/is inactive to prevent any future notifications.
"""
try:
await PrismaUser.prisma().update(
where={"id": user_id},
data={
"notifyOnAgentRun": False,
"notifyOnZeroBalance": False,
"notifyOnLowBalance": False,
"notifyOnBlockExecutionFailed": False,
"notifyOnContinuousAgentError": False,
"notifyOnDailySummary": False,
"notifyOnWeeklySummary": False,
"notifyOnMonthlySummary": False,
"notifyOnAgentApproved": False,
"notifyOnAgentRejected": False,
},
)
# Invalidate cache for this user
get_user_by_id.cache_delete(user_id)
logger.info(f"Disabled all notification preferences for user {user_id}")
except Exception as e:
raise DatabaseError(
f"Failed to disable notifications for user {user_id}: {e}"
) from e
async def get_user_email_verification(user_id: str) -> bool:
"""Get the email verification status for a user."""
try:

View File

@@ -4,12 +4,7 @@ Module for generating AI-based activity status for graph executions.
import json
import logging
from typing import TYPE_CHECKING, Any, TypedDict
try:
from typing import NotRequired
except ImportError:
from typing_extensions import NotRequired
from typing import TYPE_CHECKING, Any, NotRequired, TypedDict
from pydantic import SecretStr
@@ -151,35 +146,17 @@ async def generate_activity_status_for_execution(
"Focus on the ACTUAL TASK the user wanted done, not the internal workflow steps. "
"Avoid technical terms like 'workflow', 'execution', 'components', 'nodes', 'processing', etc. "
"Keep it to 3 sentences maximum. Be conversational and human-friendly.\n\n"
"UNDERSTAND THE INTENDED PURPOSE:\n"
"- FIRST: Read the graph description carefully to understand what the user wanted to accomplish\n"
"- The graph name and description tell you the main goal/intention of this automation\n"
"- Use this intended purpose as your PRIMARY criteria for success/failure evaluation\n"
"- Ask yourself: 'Did this execution actually accomplish what the graph was designed to do?'\n\n"
"CRITICAL OUTPUT ANALYSIS:\n"
"- Check if blocks that should produce user-facing results actually produced outputs\n"
"- Blocks with names containing 'Output', 'Post', 'Create', 'Send', 'Publish', 'Generate' are usually meant to produce final results\n"
"- If these critical blocks have NO outputs (empty recent_outputs), the task likely FAILED even if status shows 'completed'\n"
"- Sub-agents (AgentExecutorBlock) that produce no outputs usually indicate failed sub-tasks\n"
"- Most importantly: Does the execution result match what the graph description promised to deliver?\n\n"
"SUCCESS EVALUATION BASED ON INTENTION:\n"
"- If the graph is meant to 'create blog posts' → check if blog content was actually created\n"
"- If the graph is meant to 'send emails' → check if emails were actually sent\n"
"- If the graph is meant to 'analyze data' → check if analysis results were produced\n"
"- If the graph is meant to 'generate reports' → check if reports were generated\n"
"- Technical completion ≠ goal achievement. Focus on whether the USER'S INTENDED OUTCOME was delivered\n\n"
"IMPORTANT: Be HONEST about what actually happened:\n"
"- If the input was invalid/nonsensical, say so directly\n"
"- If the task failed, explain what went wrong in simple terms\n"
"- If errors occurred, focus on what the user needs to know\n"
"- Only claim success if the INTENDED PURPOSE was genuinely accomplished AND produced expected outputs\n"
"- Don't sugar-coat failures or present them as helpful feedback\n"
"- ESPECIALLY: If the graph's main purpose wasn't achieved, this is a failure regardless of 'completed' status\n\n"
"- Only claim success if the task was genuinely completed\n"
"- Don't sugar-coat failures or present them as helpful feedback\n\n"
"Understanding Errors:\n"
"- Node errors: Individual steps may fail but the overall task might still complete (e.g., one data source fails but others work)\n"
"- Graph error (in overall_status.graph_error): This means the entire execution failed and nothing was accomplished\n"
"- Missing outputs from critical blocks: Even if no errors, this means the task failed to produce expected results\n"
"- Focus on whether the graph's intended purpose was fulfilled, not whether technical steps completed"
"- Even if execution shows 'completed', check if critical nodes failed that would prevent the desired outcome\n"
"- Focus on the end result the user wanted, not whether technical steps completed"
),
},
{
@@ -188,28 +165,15 @@ async def generate_activity_status_for_execution(
f"A user ran '{graph_name}' to accomplish something. Based on this execution data, "
f"write what they achieved in simple, user-friendly terms:\n\n"
f"{json.dumps(execution_data, indent=2)}\n\n"
"ANALYSIS CHECKLIST:\n"
"1. READ graph_info.description FIRST - this tells you what the user intended to accomplish\n"
"2. Check overall_status.graph_error - if present, the entire execution failed\n"
"3. Look for nodes with 'Output', 'Post', 'Create', 'Send', 'Publish', 'Generate' in their block_name\n"
"4. Check if these critical blocks have empty recent_outputs arrays - this indicates failure\n"
"5. Look for AgentExecutorBlock (sub-agents) with no outputs - this suggests sub-task failures\n"
"6. Count how many nodes produced outputs vs total nodes - low ratio suggests problems\n"
"7. MOST IMPORTANT: Does the execution outcome match what graph_info.description promised?\n\n"
"INTENTION-BASED EVALUATION:\n"
"- If description mentions 'blog writing' → did it create blog content?\n"
"- If description mentions 'email automation' → were emails actually sent?\n"
"- If description mentions 'data analysis' → were analysis results produced?\n"
"- If description mentions 'content generation' → was content actually generated?\n"
"- If description mentions 'social media posting' → were posts actually made?\n"
"- Match the outputs to the stated intention, not just technical completion\n\n"
"CRITICAL: Check overall_status.graph_error FIRST - if present, the entire execution failed.\n"
"Then check individual node errors to understand partial failures.\n\n"
"Write 1-3 sentences about what the user accomplished, such as:\n"
"- 'I analyzed your resume and provided detailed feedback for the IT industry.'\n"
"- 'I couldn't complete the task because critical steps failed to produce any results.'\n"
"- 'I failed to generate the content you requested due to missing API access.'\n"
"- 'I couldn't analyze your resume because the input was just nonsensical text.'\n"
"- 'I failed to complete the task due to missing API access.'\n"
"- 'I extracted key information from your documents and organized it into a summary.'\n"
"- 'The task failed because the blog post creation step didn't produce any output.'\n\n"
"BE CRITICAL: If the graph's intended purpose (from description) wasn't achieved, report this as a failure even if status is 'completed'."
"- 'The task failed to run due to system configuration issues.'\n\n"
"Focus on what ACTUALLY happened, not what was attempted."
),
},
]
@@ -233,7 +197,6 @@ async def generate_activity_status_for_execution(
logger.debug(
f"Generated activity status for {graph_exec_id}: {activity_status}"
)
return activity_status
except Exception as e:

View File

@@ -1,115 +0,0 @@
"""Redis-based distributed locking for cluster coordination."""
import logging
import time
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from redis import Redis
logger = logging.getLogger(__name__)
class ClusterLock:
"""Simple Redis-based distributed lock for preventing duplicate execution."""
def __init__(self, redis: "Redis", key: str, owner_id: str, timeout: int = 300):
self.redis = redis
self.key = key
self.owner_id = owner_id
self.timeout = timeout
self._last_refresh = 0.0
def try_acquire(self) -> str | None:
"""Try to acquire the lock.
Returns:
- owner_id (self.owner_id) if successfully acquired
- different owner_id if someone else holds the lock
- None if Redis is unavailable or other error
"""
try:
success = self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout)
if success:
self._last_refresh = time.time()
return self.owner_id # Successfully acquired
# Failed to acquire, get current owner
current_value = self.redis.get(self.key)
if current_value:
current_owner = (
current_value.decode("utf-8")
if isinstance(current_value, bytes)
else str(current_value)
)
return current_owner
# Key doesn't exist but we failed to set it - race condition or Redis issue
return None
except Exception as e:
logger.error(f"ClusterLock.try_acquire failed for key {self.key}: {e}")
return None
def refresh(self) -> bool:
"""Refresh lock TTL if we still own it.
Rate limited to at most once every timeout/10 seconds (minimum 1 second).
During rate limiting, still verifies lock existence but skips TTL extension.
Setting _last_refresh to 0 bypasses rate limiting for testing.
"""
# Calculate refresh interval: max(timeout // 10, 1)
refresh_interval = max(self.timeout // 10, 1)
current_time = time.time()
# Check if we're within the rate limit period
# _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
is_rate_limited = (
self._last_refresh > 0
and (current_time - self._last_refresh) < refresh_interval
)
try:
# Always verify lock existence, even during rate limiting
current_value = self.redis.get(self.key)
if not current_value:
self._last_refresh = 0
return False
stored_owner = (
current_value.decode("utf-8")
if isinstance(current_value, bytes)
else str(current_value)
)
if stored_owner != self.owner_id:
self._last_refresh = 0
return False
# If rate limited, return True but don't update TTL or timestamp
if is_rate_limited:
return True
# Perform actual refresh
if self.redis.expire(self.key, self.timeout):
self._last_refresh = current_time
return True
self._last_refresh = 0
return False
except Exception as e:
logger.error(f"ClusterLock.refresh failed for key {self.key}: {e}")
self._last_refresh = 0
return False
def release(self):
"""Release the lock."""
if self._last_refresh == 0:
return
try:
self.redis.delete(self.key)
except Exception:
pass
self._last_refresh = 0.0

View File

@@ -1,507 +0,0 @@
"""
Integration tests for ClusterLock - Redis-based distributed locking.
Tests the complete lock lifecycle without mocking Redis to ensure
real-world behavior is correct. Covers acquisition, refresh, expiry,
contention, and error scenarios.
"""
import logging
import time
import uuid
from threading import Thread
import pytest
import redis
from .cluster_lock import ClusterLock
logger = logging.getLogger(__name__)
@pytest.fixture
def redis_client():
"""Get Redis client for testing using same config as backend."""
from backend.data.redis_client import HOST, PASSWORD, PORT
# Use same config as backend but without decode_responses since ClusterLock needs raw bytes
client = redis.Redis(
host=HOST,
port=PORT,
password=PASSWORD,
decode_responses=False, # ClusterLock needs raw bytes for ownership verification
)
# Clean up any existing test keys
try:
for key in client.scan_iter(match="test_lock:*"):
client.delete(key)
except Exception:
pass # Ignore cleanup errors
return client
@pytest.fixture
def lock_key():
"""Generate unique lock key for each test."""
return f"test_lock:{uuid.uuid4()}"
@pytest.fixture
def owner_id():
"""Generate unique owner ID for each test."""
return str(uuid.uuid4())
class TestClusterLockBasic:
"""Basic lock acquisition and release functionality."""
def test_lock_acquisition_success(self, redis_client, lock_key, owner_id):
"""Test basic lock acquisition succeeds."""
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
# Lock should be acquired successfully
result = lock.try_acquire()
assert result == owner_id # Returns our owner_id when successfully acquired
assert lock._last_refresh > 0
# Lock key should exist in Redis
assert redis_client.exists(lock_key) == 1
assert redis_client.get(lock_key).decode("utf-8") == owner_id
def test_lock_acquisition_contention(self, redis_client, lock_key):
"""Test second acquisition fails when lock is held."""
owner1 = str(uuid.uuid4())
owner2 = str(uuid.uuid4())
lock1 = ClusterLock(redis_client, lock_key, owner1, timeout=60)
lock2 = ClusterLock(redis_client, lock_key, owner2, timeout=60)
# First lock should succeed
result1 = lock1.try_acquire()
assert result1 == owner1 # Successfully acquired, returns our owner_id
# Second lock should fail and return the first owner
result2 = lock2.try_acquire()
assert result2 == owner1 # Returns the current owner (first owner)
assert lock2._last_refresh == 0
def test_lock_release_deletes_redis_key(self, redis_client, lock_key, owner_id):
"""Test lock release deletes Redis key and marks locally as released."""
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
lock.try_acquire()
assert lock._last_refresh > 0
assert redis_client.exists(lock_key) == 1
# Release should delete Redis key and mark locally as released
lock.release()
assert lock._last_refresh == 0
assert lock._last_refresh == 0.0
# Redis key should be deleted for immediate release
assert redis_client.exists(lock_key) == 0
# Another lock should be able to acquire immediately
new_owner_id = str(uuid.uuid4())
new_lock = ClusterLock(redis_client, lock_key, new_owner_id, timeout=60)
assert new_lock.try_acquire() == new_owner_id
class TestClusterLockRefresh:
"""Lock refresh and TTL management."""
def test_lock_refresh_success(self, redis_client, lock_key, owner_id):
"""Test lock refresh extends TTL."""
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
lock.try_acquire()
original_ttl = redis_client.ttl(lock_key)
# Wait a bit then refresh
time.sleep(1)
lock._last_refresh = 0 # Force refresh past rate limit
assert lock.refresh() is True
# TTL should be reset to full timeout (allow for small timing differences)
new_ttl = redis_client.ttl(lock_key)
assert new_ttl >= original_ttl or new_ttl >= 58 # Allow for timing variance
def test_lock_refresh_rate_limiting(self, redis_client, lock_key, owner_id):
"""Test refresh is rate-limited to timeout/10."""
lock = ClusterLock(
redis_client, lock_key, owner_id, timeout=100
) # 100s timeout
lock.try_acquire()
# First refresh should work
assert lock.refresh() is True
first_refresh_time = lock._last_refresh
# Immediate second refresh should be skipped (rate limited) but verify key exists
assert lock.refresh() is True # Returns True but skips actual refresh
assert lock._last_refresh == first_refresh_time # Time unchanged
def test_lock_refresh_verifies_existence_during_rate_limit(
self, redis_client, lock_key, owner_id
):
"""Test refresh verifies lock existence even during rate limiting."""
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=100)
lock.try_acquire()
# Manually delete the key (simulates expiry or external deletion)
redis_client.delete(lock_key)
# Refresh should detect missing key even during rate limit period
assert lock.refresh() is False
assert lock._last_refresh == 0
def test_lock_refresh_ownership_lost(self, redis_client, lock_key, owner_id):
"""Test refresh fails when ownership is lost."""
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
lock.try_acquire()
# Simulate another process taking the lock
different_owner = str(uuid.uuid4())
redis_client.set(lock_key, different_owner, ex=60)
# Force refresh past rate limit and verify it fails
lock._last_refresh = 0 # Force refresh past rate limit
assert lock.refresh() is False
assert lock._last_refresh == 0
def test_lock_refresh_when_not_acquired(self, redis_client, lock_key, owner_id):
"""Test refresh fails when lock was never acquired."""
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
# Refresh without acquiring should fail
assert lock.refresh() is False
class TestClusterLockExpiry:
"""Lock expiry and timeout behavior."""
def test_lock_natural_expiry(self, redis_client, lock_key, owner_id):
"""Test lock expires naturally via Redis TTL."""
lock = ClusterLock(
redis_client, lock_key, owner_id, timeout=2
) # 2 second timeout
lock.try_acquire()
assert redis_client.exists(lock_key) == 1
# Wait for expiry
time.sleep(3)
assert redis_client.exists(lock_key) == 0
# New lock with same key should succeed
new_lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
assert new_lock.try_acquire() == owner_id
def test_lock_refresh_prevents_expiry(self, redis_client, lock_key, owner_id):
"""Test refreshing prevents lock from expiring."""
lock = ClusterLock(
redis_client, lock_key, owner_id, timeout=3
) # 3 second timeout
lock.try_acquire()
# Wait and refresh before expiry
time.sleep(1)
lock._last_refresh = 0 # Force refresh past rate limit
assert lock.refresh() is True
# Wait beyond original timeout
time.sleep(2.5)
assert redis_client.exists(lock_key) == 1 # Should still exist
class TestClusterLockConcurrency:
"""Concurrent access patterns."""
def test_multiple_threads_contention(self, redis_client, lock_key):
"""Test multiple threads competing for same lock."""
num_threads = 5
successful_acquisitions = []
def try_acquire_lock(thread_id):
owner_id = f"thread_{thread_id}"
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
if lock.try_acquire() == owner_id:
successful_acquisitions.append(thread_id)
time.sleep(0.1) # Hold lock briefly
lock.release()
threads = []
for i in range(num_threads):
thread = Thread(target=try_acquire_lock, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
# Only one thread should have acquired the lock
assert len(successful_acquisitions) == 1
def test_sequential_lock_reuse(self, redis_client, lock_key):
"""Test lock can be reused after natural expiry."""
owners = [str(uuid.uuid4()) for _ in range(3)]
for i, owner_id in enumerate(owners):
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=1) # 1 second
assert lock.try_acquire() == owner_id
time.sleep(1.5) # Wait for expiry
# Verify lock expired
assert redis_client.exists(lock_key) == 0
def test_refresh_during_concurrent_access(self, redis_client, lock_key):
"""Test lock refresh works correctly during concurrent access attempts."""
owner1 = str(uuid.uuid4())
owner2 = str(uuid.uuid4())
lock1 = ClusterLock(redis_client, lock_key, owner1, timeout=5)
lock2 = ClusterLock(redis_client, lock_key, owner2, timeout=5)
# Thread 1 holds lock and refreshes
assert lock1.try_acquire() == owner1
def refresh_continuously():
for _ in range(10):
lock1._last_refresh = 0 # Force refresh
lock1.refresh()
time.sleep(0.1)
def try_acquire_continuously():
attempts = 0
while attempts < 20:
if lock2.try_acquire() == owner2:
return True
time.sleep(0.1)
attempts += 1
return False
refresh_thread = Thread(target=refresh_continuously)
acquire_thread = Thread(target=try_acquire_continuously)
refresh_thread.start()
acquire_thread.start()
refresh_thread.join()
acquire_thread.join()
# Lock1 should still own the lock due to refreshes
assert lock1._last_refresh > 0
assert lock2._last_refresh == 0
class TestClusterLockErrorHandling:
"""Error handling and edge cases."""
def test_redis_connection_failure_on_acquire(self, lock_key, owner_id):
"""Test graceful handling when Redis is unavailable during acquisition."""
# Use invalid Redis connection
bad_redis = redis.Redis(
host="invalid_host", port=1234, socket_connect_timeout=1
)
lock = ClusterLock(bad_redis, lock_key, owner_id, timeout=60)
# Should return None for Redis connection failures
result = lock.try_acquire()
assert result is None # Returns None when Redis fails
assert lock._last_refresh == 0
def test_redis_connection_failure_on_refresh(
self, redis_client, lock_key, owner_id
):
"""Test graceful handling when Redis fails during refresh."""
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
# Acquire normally
assert lock.try_acquire() == owner_id
# Replace Redis client with failing one
lock.redis = redis.Redis(
host="invalid_host", port=1234, socket_connect_timeout=1
)
# Refresh should fail gracefully
lock._last_refresh = 0 # Force refresh
assert lock.refresh() is False
assert lock._last_refresh == 0
def test_invalid_lock_parameters(self, redis_client):
"""Test validation of lock parameters."""
owner_id = str(uuid.uuid4())
# All parameters are now simple - no validation needed
# Just test basic construction works
lock = ClusterLock(redis_client, "test_key", owner_id, timeout=60)
assert lock.key == "test_key"
assert lock.owner_id == owner_id
assert lock.timeout == 60
def test_refresh_after_redis_key_deleted(self, redis_client, lock_key, owner_id):
"""Test refresh behavior when Redis key is manually deleted."""
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
lock.try_acquire()
# Manually delete the key (simulates external deletion)
redis_client.delete(lock_key)
# Refresh should fail and mark as not acquired
lock._last_refresh = 0 # Force refresh
assert lock.refresh() is False
assert lock._last_refresh == 0
class TestClusterLockDynamicRefreshInterval:
"""Dynamic refresh interval based on timeout."""
def test_refresh_interval_calculation(self, redis_client, lock_key, owner_id):
"""Test refresh interval is calculated as max(timeout/10, 1)."""
test_cases = [
(5, 1), # 5/10 = 0, but minimum is 1
(10, 1), # 10/10 = 1
(30, 3), # 30/10 = 3
(100, 10), # 100/10 = 10
(200, 20), # 200/10 = 20
(1000, 100), # 1000/10 = 100
]
for timeout, expected_interval in test_cases:
lock = ClusterLock(
redis_client, f"{lock_key}_{timeout}", owner_id, timeout=timeout
)
lock.try_acquire()
# Calculate expected interval using same logic as implementation
refresh_interval = max(timeout // 10, 1)
assert refresh_interval == expected_interval
# Test rate limiting works with calculated interval
assert lock.refresh() is True
first_refresh_time = lock._last_refresh
# Sleep less than interval - should be rate limited
time.sleep(0.1)
assert lock.refresh() is True
assert lock._last_refresh == first_refresh_time # No actual refresh
class TestClusterLockRealWorldScenarios:
"""Real-world usage patterns."""
def test_execution_coordination_simulation(self, redis_client):
"""Simulate graph execution coordination across multiple pods."""
graph_exec_id = str(uuid.uuid4())
lock_key = f"execution:{graph_exec_id}"
# Simulate 3 pods trying to execute same graph
pods = [f"pod_{i}" for i in range(3)]
execution_results = {}
def execute_graph(pod_id):
"""Simulate graph execution with cluster lock."""
lock = ClusterLock(redis_client, lock_key, pod_id, timeout=300)
if lock.try_acquire() == pod_id:
# Simulate execution work
execution_results[pod_id] = "executed"
time.sleep(0.1)
lock.release()
else:
execution_results[pod_id] = "rejected"
threads = []
for pod_id in pods:
thread = Thread(target=execute_graph, args=(pod_id,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
# Only one pod should have executed
executed_count = sum(
1 for result in execution_results.values() if result == "executed"
)
rejected_count = sum(
1 for result in execution_results.values() if result == "rejected"
)
assert executed_count == 1
assert rejected_count == 2
def test_long_running_execution_with_refresh(
self, redis_client, lock_key, owner_id
):
"""Test lock maintains ownership during long execution with periodic refresh."""
lock = ClusterLock(
redis_client, lock_key, owner_id, timeout=30
) # 30 second timeout, refresh interval = max(30//10, 1) = 3 seconds
def long_execution_with_refresh():
"""Simulate long-running execution with periodic refresh."""
assert lock.try_acquire() == owner_id
# Simulate 10 seconds of work with refreshes every 2 seconds
# This respects rate limiting - actual refreshes will happen at 0s, 3s, 6s, 9s
try:
for i in range(5): # 5 iterations * 2 seconds = 10 seconds total
time.sleep(2)
refresh_success = lock.refresh()
assert refresh_success is True, f"Refresh failed at iteration {i}"
return "completed"
finally:
lock.release()
# Should complete successfully without losing lock
result = long_execution_with_refresh()
assert result == "completed"
def test_graceful_degradation_pattern(self, redis_client, lock_key):
"""Test graceful degradation when Redis becomes unavailable."""
owner_id = str(uuid.uuid4())
lock = ClusterLock(
redis_client, lock_key, owner_id, timeout=3
) # Use shorter timeout
# Normal operation
assert lock.try_acquire() == owner_id
lock._last_refresh = 0 # Force refresh past rate limit
assert lock.refresh() is True
# Simulate Redis becoming unavailable
original_redis = lock.redis
lock.redis = redis.Redis(
host="invalid_host",
port=1234,
socket_connect_timeout=1,
decode_responses=False,
)
# Should degrade gracefully
lock._last_refresh = 0 # Force refresh past rate limit
assert lock.refresh() is False
assert lock._last_refresh == 0
# Restore Redis and verify can acquire again
lock.redis = original_redis
# Wait for original lock to expire (use longer wait for 3s timeout)
time.sleep(4)
new_lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
assert new_lock.try_acquire() == owner_id
if __name__ == "__main__":
# Run specific test for quick validation
pytest.main([__file__, "-v"])

View File

@@ -1,6 +1,5 @@
import logging
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cast
from typing import Callable, Concatenate, ParamSpec, TypeVar, cast
from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
@@ -10,7 +9,6 @@ from backend.data.execution import (
get_execution_kv_data,
get_graph_execution_meta,
get_graph_executions,
get_graph_executions_count,
get_latest_node_execution,
get_node_execution,
get_node_executions,
@@ -30,17 +28,14 @@ from backend.data.graph import (
get_node,
)
from backend.data.notifications import (
clear_all_user_notification_batches,
create_or_add_to_user_notification_batch,
empty_user_notification_batch,
get_all_batches_by_type,
get_user_notification_batch,
get_user_notification_oldest_message_in_batch,
remove_notifications_from_batch,
)
from backend.data.user import (
get_active_user_ids_in_timerange,
get_user_by_id,
get_user_email_by_id,
get_user_email_verification,
get_user_integrations,
@@ -58,10 +53,8 @@ from backend.util.service import (
)
from backend.util.settings import Config
if TYPE_CHECKING:
from fastapi import FastAPI
config = Config()
_user_credit_model = get_user_credit_model()
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
@@ -70,27 +63,24 @@ R = TypeVar("R")
async def _spend_credits(
user_id: str, cost: int, metadata: UsageTransactionMetadata
) -> int:
user_credit_model = await get_user_credit_model(user_id)
return await user_credit_model.spend_credits(user_id, cost, metadata)
return await _user_credit_model.spend_credits(user_id, cost, metadata)
async def _get_credits(user_id: str) -> int:
user_credit_model = await get_user_credit_model(user_id)
return await user_credit_model.get_credits(user_id)
return await _user_credit_model.get_credits(user_id)
class DatabaseManager(AppService):
@asynccontextmanager
async def lifespan(self, app: "FastAPI"):
async with super().lifespan(app):
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
await db.connect()
logger.info(f"[{self.service_name}] ✅ Ready")
yield
def run_service(self) -> None:
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
self.run_and_wait(db.connect())
super().run_service()
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
await db.disconnect()
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
self.run_and_wait(db.disconnect())
async def health_check(self) -> str:
if not db.is_connected():
@@ -121,7 +111,6 @@ class DatabaseManager(AppService):
# Executions
get_graph_executions = _(get_graph_executions)
get_graph_executions_count = _(get_graph_executions_count)
get_graph_execution_meta = _(get_graph_execution_meta)
create_graph_execution = _(create_graph_execution)
get_node_execution = _(get_node_execution)
@@ -153,18 +142,15 @@ class DatabaseManager(AppService):
# User Comms - async
get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
get_user_by_id = _(get_user_by_id)
get_user_email_by_id = _(get_user_email_by_id)
get_user_email_verification = _(get_user_email_verification)
get_user_notification_preference = _(get_user_notification_preference)
# Notifications - async
clear_all_user_notification_batches = _(clear_all_user_notification_batches)
create_or_add_to_user_notification_batch = _(
create_or_add_to_user_notification_batch
)
empty_user_notification_batch = _(empty_user_notification_batch)
remove_notifications_from_batch = _(remove_notifications_from_batch)
get_all_batches_by_type = _(get_all_batches_by_type)
get_user_notification_batch = _(get_user_notification_batch)
get_user_notification_oldest_message_in_batch = _(
@@ -193,7 +179,6 @@ class DatabaseManagerClient(AppServiceClient):
# Executions
get_graph_executions = _(d.get_graph_executions)
get_graph_executions_count = _(d.get_graph_executions_count)
get_graph_execution_meta = _(d.get_graph_execution_meta)
get_node_executions = _(d.get_node_executions)
update_node_execution_status = _(d.update_node_execution_status)
@@ -239,7 +224,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
get_node = d.get_node
get_node_execution = d.get_node_execution
get_node_executions = d.get_node_executions
get_user_by_id = d.get_user_by_id
get_user_integrations = d.get_user_integrations
upsert_execution_input = d.upsert_execution_input
upsert_execution_output = d.upsert_execution_output
@@ -257,12 +241,10 @@ class DatabaseManagerAsyncClient(AppServiceClient):
get_user_notification_preference = d.get_user_notification_preference
# Notifications
clear_all_user_notification_batches = d.clear_all_user_notification_batches
create_or_add_to_user_notification_batch = (
d.create_or_add_to_user_notification_batch
)
empty_user_notification_batch = d.empty_user_notification_batch
remove_notifications_from_batch = d.remove_notifications_from_batch
get_all_batches_by_type = d.get_all_batches_by_type
get_user_notification_batch = d.get_user_notification_batch
get_user_notification_oldest_message_in_batch = (

View File

@@ -3,42 +3,16 @@ import logging
import os
import threading
import time
import uuid
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
import sentry_sdk
from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from prometheus_client import Gauge, start_http_server
from redis.asyncio.lock import Lock as AsyncRedisLock
from redis.asyncio.lock import Lock as RedisLock
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentOutputBlock
from backend.data import redis_client as redis
from backend.data.block import (
BlockInput,
BlockOutput,
BlockOutputEntry,
BlockSchema,
get_block,
)
from backend.data.credit import UsageTransactionMetadata
from backend.data.dynamic_fields import parse_execution_output
from backend.data.execution import (
ExecutionQueue,
ExecutionStatus,
GraphExecution,
GraphExecutionEntry,
NodeExecutionEntry,
NodeExecutionResult,
NodesInputMasks,
UserContext,
)
from backend.data.graph import Link, Node
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
AgentRunData,
@@ -51,21 +25,50 @@ from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.activity_status_generator import (
generate_activity_status_for_execution,
)
from backend.executor.utils import LogMetadata
from backend.notifications.notifications import queue_notification
from backend.util.exceptions import InsufficientBalanceError, ModerationError
if TYPE_CHECKING:
from backend.executor import DatabaseManagerClient, DatabaseManagerAsyncClient
from prometheus_client import Gauge, start_http_server
from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis_client as redis
from backend.data.block import (
BlockInput,
BlockOutput,
BlockOutputEntry,
BlockSchema,
get_block,
)
from backend.data.credit import UsageTransactionMetadata
from backend.data.execution import (
ExecutionQueue,
ExecutionStatus,
GraphExecution,
GraphExecutionEntry,
NodeExecutionEntry,
NodeExecutionResult,
NodesInputMasks,
UserContext,
)
from backend.data.graph import Link, Node
from backend.executor.utils import (
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
GRAPH_EXECUTION_QUEUE_NAME,
CancelExecutionEvent,
ExecutionOutputEntry,
LogMetadata,
NodeExecutionProgress,
block_usage_cost,
create_execution_queue_config,
execution_usage_cost,
parse_execution_output,
validate_exec,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.notifications.notifications import queue_notification
from backend.server.v2.AutoMod.manager import automod_manager
from backend.util import json
from backend.util.clients import (
@@ -81,24 +84,13 @@ from backend.util.decorator import (
error_logged,
time_measured,
)
from backend.util.exceptions import InsufficientBalanceError, ModerationError
from backend.util.file import clean_exec_files
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.metrics import DiscordChannel
from backend.util.process import AppProcess, set_service_name
from backend.util.retry import (
continuous_retry,
func_retry,
send_rate_limited_discord_alert,
)
from backend.util.retry import continuous_retry, func_retry
from backend.util.settings import Settings
from .cluster_lock import ClusterLock
if TYPE_CHECKING:
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
_logger = logging.getLogger(__name__)
logger = TruncatedLogger(_logger, prefix="[GraphExecutor]")
settings = Settings()
@@ -114,7 +106,6 @@ utilization_gauge = Gauge(
"Ratio of active graph runs to max graph workers",
)
# Thread-local storage for ExecutionProcessor instances
_tls = threading.local()
@@ -126,14 +117,10 @@ def init_worker():
def execute_graph(
graph_exec_entry: "GraphExecutionEntry",
cancel_event: threading.Event,
cluster_lock: ClusterLock,
graph_exec_entry: "GraphExecutionEntry", cancel_event: threading.Event
):
"""Execute graph using thread-local ExecutionProcessor instance"""
return _tls.processor.on_graph_execution(
graph_exec_entry, cancel_event, cluster_lock
)
return _tls.processor.on_graph_execution(graph_exec_entry, cancel_event)
T = TypeVar("T")
@@ -190,7 +177,6 @@ async def execute_node(
_input_data.inputs = input_data
if nodes_input_masks:
_input_data.nodes_input_masks = nodes_input_masks
_input_data.user_id = user_id
input_data = _input_data.model_dump()
data.inputs = input_data
@@ -225,37 +211,14 @@ async def execute_node(
extra_exec_kwargs[field_name] = credentials
output_size = 0
# sentry tracking nonsense to get user counts for blocks because isolation scopes don't work :(
scope = sentry_sdk.get_current_scope()
# save the tags
original_user = scope._user
original_tags = dict(scope._tags) if scope._tags else {}
# Set user ID for error tracking
scope.set_user({"id": user_id})
scope.set_tag("graph_id", graph_id)
scope.set_tag("node_id", node_id)
scope.set_tag("block_name", node_block.name)
scope.set_tag("block_id", node_block.id)
for k, v in (data.user_context or UserContext(timezone="UTC")).model_dump().items():
scope.set_tag(f"user_context.{k}", v)
try:
async for output_name, output_data in node_block.execute(
input_data, **extra_exec_kwargs
):
output_data = json.to_dict(output_data)
output_data = json.convert_pydantic_to_json(output_data)
output_size += len(json.dumps(output_data))
log_metadata.debug("Node produced output", **{output_name: output_data})
yield output_name, output_data
except Exception:
# Capture exception WITH context still set before restoring scope
sentry_sdk.capture_exception(scope=scope)
sentry_sdk.flush() # Ensure it's sent before we restore scope
# Re-raise to maintain normal error flow
raise
finally:
# Ensure credentials are released even if execution fails
if creds_lock and (await creds_lock.locked()) and (await creds_lock.owned()):
@@ -270,10 +233,6 @@ async def execute_node(
execution_stats.input_size = input_size
execution_stats.output_size = output_size
# Restore scope AFTER error has been captured
scope._user = original_user
scope._tags = original_tags
async def _enqueue_next_nodes(
db_client: "DatabaseManagerAsyncClient",
@@ -470,7 +429,7 @@ class ExecutionProcessor:
graph_id=node_exec.graph_id,
node_eid=node_exec.node_exec_id,
node_id=node_exec.node_id,
block_name=b.name if (b := get_block(node_exec.block_id)) else "-",
block_name="-",
)
db_client = get_db_async_client()
node = await db_client.get_node(node_exec.node_id)
@@ -598,6 +557,7 @@ class ExecutionProcessor:
await persist_output(
"error", str(stats.error) or type(stats.error).__name__
)
return status
@func_retry
@@ -623,7 +583,6 @@ class ExecutionProcessor:
self,
graph_exec: GraphExecutionEntry,
cancel: threading.Event,
cluster_lock: ClusterLock,
):
log_metadata = LogMetadata(
logger=_logger,
@@ -682,7 +641,6 @@ class ExecutionProcessor:
cancel=cancel,
log_metadata=log_metadata,
execution_stats=exec_stats,
cluster_lock=cluster_lock,
)
exec_stats.walltime += timing_info.wall_time
exec_stats.cputime += timing_info.cpu_time
@@ -784,7 +742,6 @@ class ExecutionProcessor:
cancel: threading.Event,
log_metadata: LogMetadata,
execution_stats: GraphExecutionStats,
cluster_lock: ClusterLock,
) -> ExecutionStatus:
"""
Returns:
@@ -970,7 +927,7 @@ class ExecutionProcessor:
and execution_queue.empty()
and (running_node_execution or running_node_evaluation)
):
cluster_lock.refresh()
# There is nothing to execute, and no output to process, let's relax for a while.
time.sleep(0.1)
# loop done --------------------------------------------------
@@ -1012,31 +969,16 @@ class ExecutionProcessor:
if isinstance(e, Exception)
else Exception(f"{e.__class__.__name__}: {e}")
)
if not execution_stats.error:
execution_stats.error = str(error)
known_errors = (InsufficientBalanceError, ModerationError)
if isinstance(error, known_errors):
execution_stats.error = str(error)
return ExecutionStatus.FAILED
execution_status = ExecutionStatus.FAILED
log_metadata.exception(
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
)
# Send rate-limited Discord alert for unknown/unexpected errors
send_rate_limited_discord_alert(
"graph_execution",
error,
"unknown_error",
f"🚨 **Unknown Graph Execution Error**\n"
f"User: {graph_exec.user_id}\n"
f"Graph ID: {graph_exec.graph_id}\n"
f"Execution ID: {graph_exec.graph_exec_id}\n"
f"Error Type: {type(error).__name__}\n"
f"Error: {str(error)[:200]}{'...' if len(str(error)) > 200 else ''}\n",
)
raise
finally:
@@ -1211,9 +1153,9 @@ class ExecutionProcessor:
f"❌ **Insufficient Funds Alert**\n"
f"User: {user_email or user_id}\n"
f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n"
f"Current balance: ${e.balance / 100:.2f}\n"
f"Attempted cost: ${abs(e.amount) / 100:.2f}\n"
f"Shortfall: ${abs(shortfall) / 100:.2f}\n"
f"Current balance: ${e.balance/100:.2f}\n"
f"Attempted cost: ${abs(e.amount)/100:.2f}\n"
f"Shortfall: ${abs(shortfall)/100:.2f}\n"
f"[View User Details]({base_url}/admin/spending?search={user_email})"
)
@@ -1260,9 +1202,9 @@ class ExecutionProcessor:
alert_message = (
f"⚠️ **Low Balance Alert**\n"
f"User: {user_email or user_id}\n"
f"Balance dropped below ${LOW_BALANCE_THRESHOLD / 100:.2f}\n"
f"Current balance: ${current_balance / 100:.2f}\n"
f"Transaction cost: ${transaction_cost / 100:.2f}\n"
f"Balance dropped below ${LOW_BALANCE_THRESHOLD/100:.2f}\n"
f"Current balance: ${current_balance/100:.2f}\n"
f"Transaction cost: ${transaction_cost/100:.2f}\n"
f"[View User Details]({base_url}/admin/spending?search={user_email})"
)
get_notification_manager_client().discord_system_alert(
@@ -1277,7 +1219,6 @@ class ExecutionManager(AppProcess):
super().__init__()
self.pool_size = settings.config.num_graph_workers
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
self.executor_id = str(uuid.uuid4())
self._executor = None
self._stop_consuming = None
@@ -1287,8 +1228,6 @@ class ExecutionManager(AppProcess):
self._run_thread = None
self._run_client = None
self._execution_locks = {}
@property
def cancel_thread(self) -> threading.Thread:
if self._cancel_thread is None:
@@ -1493,79 +1432,20 @@ class ExecutionManager(AppProcess):
return
graph_exec_id = graph_exec_entry.graph_exec_id
user_id = graph_exec_entry.user_id
graph_id = graph_exec_entry.graph_id
logger.info(
f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}, user_id={user_id}"
f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}"
)
# Check user rate limit before processing
try:
# Only check executions from the last 24 hours for performance
current_running_count = get_db_client().get_graph_executions_count(
user_id=user_id,
graph_id=graph_id,
statuses=[ExecutionStatus.RUNNING],
created_time_gte=datetime.now(timezone.utc) - timedelta(hours=24),
)
if (
current_running_count
>= settings.config.max_concurrent_graph_executions_per_user
):
logger.warning(
f"[{self.service_name}] Rate limit exceeded for user {user_id} on graph {graph_id}: "
f"{current_running_count}/{settings.config.max_concurrent_graph_executions_per_user} running executions"
)
_ack_message(reject=True, requeue=True)
return
except Exception as e:
logger.error(
f"[{self.service_name}] Failed to check rate limit for user {user_id}: {e}, proceeding with execution"
)
# If rate limit check fails, proceed to avoid blocking executions
# Check for local duplicate execution first
if graph_exec_id in self.active_graph_runs:
logger.warning(
f"[{self.service_name}] Graph {graph_exec_id} already running locally; rejecting duplicate."
# TODO: Make this check cluster-wide, prevent duplicate runs across executor pods.
logger.error(
f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
)
_ack_message(reject=True, requeue=True)
_ack_message(reject=True, requeue=False)
return
# Try to acquire cluster-wide execution lock
cluster_lock = ClusterLock(
redis=redis.get_redis(),
key=f"exec_lock:{graph_exec_id}",
owner_id=self.executor_id,
timeout=settings.config.cluster_lock_timeout,
)
current_owner = cluster_lock.try_acquire()
if current_owner != self.executor_id:
# Either someone else has it or Redis is unavailable
if current_owner is not None:
logger.warning(
f"[{self.service_name}] Graph {graph_exec_id} already running on pod {current_owner}"
)
_ack_message(reject=True, requeue=False)
else:
logger.warning(
f"[{self.service_name}] Could not acquire lock for {graph_exec_id} - Redis unavailable"
)
_ack_message(reject=True, requeue=True)
return
self._execution_locks[graph_exec_id] = cluster_lock
logger.info(
f"[{self.service_name}] Acquired cluster lock for {graph_exec_id} with executor {self.executor_id}"
)
cancel_event = threading.Event()
future = self.executor.submit(
execute_graph, graph_exec_entry, cancel_event, cluster_lock
)
future = self.executor.submit(execute_graph, graph_exec_entry, cancel_event)
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
self._update_prompt_metrics()
@@ -1584,10 +1464,6 @@ class ExecutionManager(AppProcess):
f"[{self.service_name}] Error in run completion callback: {e}"
)
finally:
# Release the cluster-wide execution lock
if graph_exec_id in self._execution_locks:
self._execution_locks[graph_exec_id].release()
del self._execution_locks[graph_exec_id]
self._cleanup_completed_runs()
future.add_done_callback(_on_run_done)
@@ -1670,10 +1546,6 @@ class ExecutionManager(AppProcess):
f"{prefix} ⏳ Still waiting for {len(self.active_graph_runs)} executions: {ids}"
)
for graph_exec_id in self.active_graph_runs:
if lock := self._execution_locks.get(graph_exec_id):
lock.refresh()
time.sleep(wait_interval)
waited += wait_interval
@@ -1691,15 +1563,6 @@ class ExecutionManager(AppProcess):
except Exception as e:
logger.error(f"{prefix} ⚠️ Error during executor shutdown: {type(e)} {e}")
# Release remaining execution locks
try:
for lock in self._execution_locks.values():
lock.release()
self._execution_locks.clear()
logger.info(f"{prefix} ✅ Released execution locks")
except Exception as e:
logger.warning(f"{prefix} ⚠️ Failed to release all locks: {e}")
# Disconnect the run execution consumer
self._stop_message_consumers(
self.run_thread,
@@ -1714,8 +1577,6 @@ class ExecutionManager(AppProcess):
logger.info(f"{prefix} ✅ Finished GraphExec cleanup")
super().cleanup()
# ------- UTILITIES ------- #
@@ -1807,18 +1668,15 @@ def update_graph_execution_state(
@asynccontextmanager
async def synchronized(key: str, timeout: int = settings.config.cluster_lock_timeout):
async def synchronized(key: str, timeout: int = 60):
r = await redis.get_redis_async()
lock: AsyncRedisLock = r.lock(f"lock:{key}", timeout=timeout)
lock: RedisLock = r.lock(f"lock:{key}", timeout=timeout)
try:
await lock.acquire()
yield
finally:
if await lock.locked() and await lock.owned():
try:
await lock.release()
except Exception as e:
logger.warning(f"Failed to release lock for key {key}: {e}")
await lock.release()
def increment_execution_count(user_id: str) -> int:

View File

@@ -248,7 +248,7 @@ class Scheduler(AppService):
raise UnhealthyServiceError("Scheduler is still initializing")
# Check if we're in the middle of cleanup
if self._shutting_down:
if self.cleaned_up:
return await super().health_check()
# Normal operation - check if scheduler is running
@@ -375,6 +375,7 @@ class Scheduler(AppService):
super().run_service()
def cleanup(self):
super().cleanup()
if self.scheduler:
logger.info("⏳ Shutting down scheduler...")
self.scheduler.shutdown(wait=True)
@@ -389,7 +390,7 @@ class Scheduler(AppService):
logger.info("⏳ Waiting for event loop thread to finish...")
_event_loop_thread.join(timeout=SCHEDULER_OPERATION_TIMEOUT_SECONDS)
super().cleanup()
logger.info("Scheduler cleanup complete.")
@expose
def add_graph_execution_schedule(

View File

@@ -4,7 +4,7 @@ import threading
import time
from collections import defaultdict
from concurrent.futures import Future
from typing import Mapping, Optional, cast
from typing import Any, Mapping, Optional, cast
from pydantic import BaseModel, JsonValue, ValidationError
@@ -20,9 +20,6 @@ from backend.data.block import (
)
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.db import prisma
# Import dynamic field utilities from centralized location
from backend.data.dynamic_fields import merge_execution_input
from backend.data.execution import (
ExecutionStatus,
GraphExecutionStats,
@@ -34,7 +31,6 @@ from backend.data.graph import GraphModel, Node
from backend.data.model import CredentialsMetaInput
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.user import get_user_by_id
from backend.util.cache import cached
from backend.util.clients import (
get_async_execution_event_bus,
get_async_execution_queue,
@@ -42,12 +38,12 @@ from backend.util.clients import (
get_integration_credentials_store,
)
from backend.util.exceptions import GraphValidationError, NotFoundError
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
from backend.util.logging import TruncatedLogger
from backend.util.mock import MockObject
from backend.util.settings import Config
from backend.util.type import convert
@cached(maxsize=1000, ttl_seconds=3600)
async def get_user_context(user_id: str) -> UserContext:
"""
Get UserContext for a user, always returns a valid context with timezone.
@@ -55,11 +51,7 @@ async def get_user_context(user_id: str) -> UserContext:
"""
user_context = UserContext(timezone="UTC") # Default to UTC
try:
if prisma.is_connected():
user = await get_user_by_id(user_id)
else:
user = await get_database_manager_async_client().get_user_by_id(user_id)
user = await get_user_by_id(user_id)
if user and user.timezone and user.timezone != "not-set":
user_context.timezone = user.timezone
logger.debug(f"Retrieved user context: timezone={user.timezone}")
@@ -99,11 +91,7 @@ class LogMetadata(TruncatedLogger):
"node_id": node_id,
"block_name": block_name,
}
prefix = (
"[ExecutionManager]"
if is_structured_logging_enabled()
else f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|neid:{node_eid}|{block_name}]" # noqa
)
prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|neid:{node_eid}|{block_name}]"
super().__init__(
logger,
max_length=max_length,
@@ -198,7 +186,195 @@ def _is_cost_filter_match(cost_filter: BlockInput, input_data: BlockInput) -> bo
# ============ Execution Input Helpers ============ #
# Dynamic field utilities are now imported from backend.data.dynamic_fields
# --------------------------------------------------------------------------- #
# Delimiters
# --------------------------------------------------------------------------- #
LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"
_DELIMS = (LIST_SPLIT, DICT_SPLIT, OBJC_SPLIT)
# --------------------------------------------------------------------------- #
# Tokenisation utilities
# --------------------------------------------------------------------------- #
def _next_delim(s: str) -> tuple[str | None, int]:
"""
Return the *earliest* delimiter appearing in `s` and its index.
If none present → (None, -1).
"""
first: str | None = None
pos = len(s) # sentinel: larger than any real index
for d in _DELIMS:
i = s.find(d)
if 0 <= i < pos:
first, pos = d, i
return first, (pos if first else -1)
def _tokenise(path: str) -> list[tuple[str, str]] | None:
"""
Convert the raw path string (starting with a delimiter) into
[ (delimiter, identifier), … ] or None if the syntax is malformed.
"""
tokens: list[tuple[str, str]] = []
while path:
# 1. Which delimiter starts this chunk?
delim = next((d for d in _DELIMS if path.startswith(d)), None)
if delim is None:
return None # invalid syntax
# 2. Slice off the delimiter, then up to the next delimiter (or EOS)
path = path[len(delim) :]
nxt_delim, pos = _next_delim(path)
token, path = (
path[: pos if pos != -1 else len(path)],
path[pos if pos != -1 else len(path) :],
)
if token == "":
return None # empty identifier is invalid
tokens.append((delim, token))
return tokens
# --------------------------------------------------------------------------- #
# Public API parsing (flattened ➜ concrete)
# --------------------------------------------------------------------------- #
def parse_execution_output(output: BlockOutputEntry, name: str) -> JsonValue | None:
"""
Retrieve a nested value out of `output` using the flattened *name*.
On any failure (wrong name, wrong type, out-of-range, bad path)
returns **None**.
"""
base_name, data = output
# Exact match → whole object
if name == base_name:
return data
# Must start with the expected name
if not name.startswith(base_name):
return None
path = name[len(base_name) :]
if not path:
return None # nothing left to parse
tokens = _tokenise(path)
if tokens is None:
return None
cur: JsonValue = data
for delim, ident in tokens:
if delim == LIST_SPLIT:
# list[index]
try:
idx = int(ident)
except ValueError:
return None
if not isinstance(cur, list) or idx >= len(cur):
return None
cur = cur[idx]
elif delim == DICT_SPLIT:
if not isinstance(cur, dict) or ident not in cur:
return None
cur = cur[ident]
elif delim == OBJC_SPLIT:
if not hasattr(cur, ident):
return None
cur = getattr(cur, ident)
else:
return None # unreachable
return cur
def _assign(container: Any, tokens: list[tuple[str, str]], value: Any) -> Any:
"""
Recursive helper that *returns* the (possibly new) container with
`value` assigned along the remaining `tokens` path.
"""
if not tokens:
return value # leaf reached
delim, ident = tokens[0]
rest = tokens[1:]
# ---------- list ----------
if delim == LIST_SPLIT:
try:
idx = int(ident)
except ValueError:
raise ValueError("index must be an integer")
if container is None:
container = []
elif not isinstance(container, list):
container = list(container) if hasattr(container, "__iter__") else []
while len(container) <= idx:
container.append(None)
container[idx] = _assign(container[idx], rest, value)
return container
# ---------- dict ----------
if delim == DICT_SPLIT:
if container is None:
container = {}
elif not isinstance(container, dict):
container = dict(container) if hasattr(container, "items") else {}
container[ident] = _assign(container.get(ident), rest, value)
return container
# ---------- object ----------
if delim == OBJC_SPLIT:
if container is None or not isinstance(container, MockObject):
container = MockObject()
setattr(
container,
ident,
_assign(getattr(container, ident, None), rest, value),
)
return container
return value # unreachable
def merge_execution_input(data: BlockInput) -> BlockInput:
"""
Reconstruct nested objects from a *flattened* dict of key → value.
Raises ValueError on syntactically invalid list indices.
"""
merged: BlockInput = {}
for key, value in data.items():
# Split off the base name (before the first delimiter, if any)
delim, pos = _next_delim(key)
if delim is None:
merged[key] = value
continue
base, path = key[:pos], key[pos:]
tokens = _tokenise(path)
if tokens is None:
# Invalid key; treat as scalar under the raw name
merged[key] = value
continue
merged[base] = _assign(merged.get(base), tokens, value)
data.update(merged)
return data
def validate_exec(

View File

@@ -3,7 +3,7 @@ from typing import cast
import pytest
from pytest_mock import MockerFixture
from backend.data.dynamic_fields import merge_execution_input, parse_execution_output
from backend.executor.utils import merge_execution_input, parse_execution_output
from backend.util.mock import MockObject

View File

@@ -151,10 +151,7 @@ class IntegrationCredentialsManager:
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
await self.store.update_creds(user_id, fresh_credentials)
if _lock and (await _lock.locked()) and (await _lock.owned()):
try:
await _lock.release()
except Exception as e:
logger.warning(f"Failed to release OAuth refresh lock: {e}")
await _lock.release()
credentials = fresh_credentials
return credentials
@@ -187,10 +184,7 @@ class IntegrationCredentialsManager:
yield
finally:
if (await lock.locked()) and (await lock.owned()):
try:
await lock.release()
except Exception as e:
logger.warning(f"Failed to release credentials lock: {e}")
await lock.release()
async def release_all_locks(self):
"""Call this on process termination to ensure all locks are released"""

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING
from backend.util.cache import cached
from autogpt_libs.utils.cache import cached
if TYPE_CHECKING:
from ..providers import ProviderName
@@ -8,7 +8,7 @@ if TYPE_CHECKING:
# --8<-- [start:load_webhook_managers]
@cached(ttl_seconds=3600)
@cached()
def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]:
webhook_managers = {}

View File

@@ -25,11 +25,7 @@ from backend.data.notifications import (
get_summary_params_type,
)
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.user import (
disable_all_user_notifications,
generate_unsubscribe_link,
set_user_email_verification,
)
from backend.data.user import generate_unsubscribe_link
from backend.notifications.email import EmailSender
from backend.util.clients import get_database_manager_async_client
from backend.util.logging import TruncatedLogger
@@ -42,7 +38,7 @@ from backend.util.service import (
endpoint_to_sync,
expose,
)
from backend.util.settings import AppEnvironment, Settings
from backend.util.settings import Settings
logger = TruncatedLogger(logging.getLogger(__name__), "[NotificationManager]")
settings = Settings()
@@ -128,12 +124,6 @@ def get_routing_key(event_type: NotificationType) -> str:
def queue_notification(event: NotificationEventModel) -> NotificationResult:
"""Queue a notification - exposed method for other services to call"""
# Disable in production
if settings.config.app_env == AppEnvironment.PRODUCTION:
return NotificationResult(
success=True,
message="Queueing notifications is disabled in production",
)
try:
logger.debug(f"Received Request to queue {event=}")
@@ -161,12 +151,6 @@ def queue_notification(event: NotificationEventModel) -> NotificationResult:
async def queue_notification_async(event: NotificationEventModel) -> NotificationResult:
"""Queue a notification - exposed method for other services to call"""
# Disable in production
if settings.config.app_env == AppEnvironment.PRODUCTION:
return NotificationResult(
success=True,
message="Queueing notifications is disabled in production",
)
try:
logger.debug(f"Received Request to queue {event=}")
@@ -229,9 +213,6 @@ class NotificationManager(AppService):
@expose
async def queue_weekly_summary(self):
# disable in prod
if settings.config.app_env == AppEnvironment.PRODUCTION:
return
# Use the existing event loop instead of creating a new one with asyncio.run()
asyncio.create_task(self._queue_weekly_summary())
@@ -245,9 +226,7 @@ class NotificationManager(AppService):
logger.info(
f"Querying for active users between {start_time} and {current_time}"
)
users = await get_database_manager_async_client(
should_retry=False
).get_active_user_ids_in_timerange(
users = await get_database_manager_async_client().get_active_user_ids_in_timerange(
end_time=current_time.isoformat(),
start_time=start_time.isoformat(),
)
@@ -274,9 +253,6 @@ class NotificationManager(AppService):
async def process_existing_batches(
self, notification_types: list[NotificationType]
):
# disable in prod
if settings.config.app_env == AppEnvironment.PRODUCTION:
return
# Use the existing event loop instead of creating a new process
asyncio.create_task(self._process_existing_batches(notification_types))
@@ -290,15 +266,15 @@ class NotificationManager(AppService):
for notification_type in notification_types:
# Get all batches for this notification type
batches = await get_database_manager_async_client(
should_retry=False
).get_all_batches_by_type(notification_type)
batches = (
await get_database_manager_async_client().get_all_batches_by_type(
notification_type
)
)
for batch in batches:
# Check if batch has aged out
oldest_message = await get_database_manager_async_client(
should_retry=False
).get_user_notification_oldest_message_in_batch(
oldest_message = await get_database_manager_async_client().get_user_notification_oldest_message_in_batch(
batch.user_id, notification_type
)
@@ -313,9 +289,9 @@ class NotificationManager(AppService):
# If batch has aged out, process it
if oldest_message.created_at + max_delay < current_time:
recipient_email = await get_database_manager_async_client(
should_retry=False
).get_user_email_by_id(batch.user_id)
recipient_email = await get_database_manager_async_client().get_user_email_by_id(
batch.user_id
)
if not recipient_email:
logger.error(
@@ -332,25 +308,21 @@ class NotificationManager(AppService):
f"User {batch.user_id} does not want to receive {notification_type} notifications"
)
# Clear the batch
await get_database_manager_async_client(
should_retry=False
).empty_user_notification_batch(
await get_database_manager_async_client().empty_user_notification_batch(
batch.user_id, notification_type
)
continue
batch_data = await get_database_manager_async_client(
should_retry=False
).get_user_notification_batch(batch.user_id, notification_type)
batch_data = await get_database_manager_async_client().get_user_notification_batch(
batch.user_id, notification_type
)
if not batch_data or not batch_data.notifications:
logger.error(
f"Batch data not found for user {batch.user_id}"
)
# Clear the batch
await get_database_manager_async_client(
should_retry=False
).empty_user_notification_batch(
await get_database_manager_async_client().empty_user_notification_batch(
batch.user_id, notification_type
)
continue
@@ -386,9 +358,7 @@ class NotificationManager(AppService):
)
# Clear the batch
await get_database_manager_async_client(
should_retry=False
).empty_user_notification_batch(
await get_database_manager_async_client().empty_user_notification_batch(
batch.user_id, notification_type
)
@@ -443,13 +413,15 @@ class NotificationManager(AppService):
self, user_id: str, event_type: NotificationType
) -> bool:
"""Check if a user wants to receive a notification based on their preferences and email verification status"""
validated_email = await get_database_manager_async_client(
should_retry=False
).get_user_email_verification(user_id)
validated_email = (
await get_database_manager_async_client().get_user_email_verification(
user_id
)
)
preference = (
await get_database_manager_async_client(
should_retry=False
).get_user_notification_preference(user_id)
await get_database_manager_async_client().get_user_notification_preference(
user_id
)
).preferences.get(event_type, True)
# only if both are true, should we email this person
return validated_email and preference
@@ -465,9 +437,7 @@ class NotificationManager(AppService):
try:
# Get summary data from the database
summary_data = await get_database_manager_async_client(
should_retry=False
).get_user_execution_summary_data(
summary_data = await get_database_manager_async_client().get_user_execution_summary_data(
user_id=user_id,
start_time=params.start_date,
end_time=params.end_date,
@@ -554,13 +524,13 @@ class NotificationManager(AppService):
self, user_id: str, event_type: NotificationType, event: NotificationEventModel
) -> bool:
await get_database_manager_async_client(
should_retry=False
).create_or_add_to_user_notification_batch(user_id, event_type, event)
await get_database_manager_async_client().create_or_add_to_user_notification_batch(
user_id, event_type, event
)
oldest_message = await get_database_manager_async_client(
should_retry=False
).get_user_notification_oldest_message_in_batch(user_id, event_type)
oldest_message = await get_database_manager_async_client().get_user_notification_oldest_message_in_batch(
user_id, event_type
)
if not oldest_message:
logger.error(
f"Batch for user {user_id} and type {event_type} has no oldest message whichshould never happen!!!!!!!!!!!!!!!!"
@@ -610,9 +580,11 @@ class NotificationManager(AppService):
return False
logger.debug(f"Processing immediate notification: {event}")
recipient_email = await get_database_manager_async_client(
should_retry=False
).get_user_email_by_id(event.user_id)
recipient_email = (
await get_database_manager_async_client().get_user_email_by_id(
event.user_id
)
)
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False
@@ -647,9 +619,11 @@ class NotificationManager(AppService):
return False
logger.info(f"Processing batch notification: {event}")
recipient_email = await get_database_manager_async_client(
should_retry=False
).get_user_email_by_id(event.user_id)
recipient_email = (
await get_database_manager_async_client().get_user_email_by_id(
event.user_id
)
)
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False
@@ -668,9 +642,11 @@ class NotificationManager(AppService):
if not should_send:
logger.info("Batch not old enough to send")
return False
batch = await get_database_manager_async_client(
should_retry=False
).get_user_notification_batch(event.user_id, event.type)
batch = (
await get_database_manager_async_client().get_user_notification_batch(
event.user_id, event.type
)
)
if not batch or not batch.notifications:
logger.error(f"Batch not found for user {event.user_id}")
return False
@@ -681,7 +657,6 @@ class NotificationManager(AppService):
get_notif_data_type(db_event.type)
].model_validate(
{
"id": db_event.id, # Include ID from database
"user_id": event.user_id,
"type": db_event.type,
"data": db_event.data,
@@ -704,9 +679,6 @@ class NotificationManager(AppService):
chunk_sent = False
for attempt_size in [chunk_size, 50, 25, 10, 5, 1]:
chunk = batch_messages[i : i + attempt_size]
chunk_ids = [
msg.id for msg in chunk if msg.id
] # Extract IDs for removal
try:
# Try to render the email to check its size
@@ -733,23 +705,6 @@ class NotificationManager(AppService):
user_unsub_link=unsub_link,
)
# Remove successfully sent notifications immediately
if chunk_ids:
try:
await get_database_manager_async_client(
should_retry=False
).remove_notifications_from_batch(
event.user_id, event.type, chunk_ids
)
logger.info(
f"Removed {len(chunk_ids)} sent notifications from batch"
)
except Exception as e:
logger.error(
f"Failed to remove sent notifications: {e}"
)
# Continue anyway - better to risk duplicates than lose emails
# Track successful sends
successfully_sent_count += len(chunk)
@@ -767,137 +722,13 @@ class NotificationManager(AppService):
i += len(chunk)
chunk_sent = True
break
else:
# Message is too large even after size reduction
if attempt_size == 1:
logger.error(
f"Failed to send notification at index {i}: "
f"Single notification exceeds email size limit "
f"({len(test_message):,} chars > {MAX_EMAIL_SIZE:,} chars). "
f"Removing permanently from batch - will not retry."
)
# Remove the oversized notification permanently - it will NEVER fit
if chunk_ids:
try:
await get_database_manager_async_client(
should_retry=False
).remove_notifications_from_batch(
event.user_id, event.type, chunk_ids
)
logger.info(
f"Removed oversized notification {chunk_ids[0]} from batch permanently"
)
except Exception as e:
logger.error(
f"Failed to remove oversized notification: {e}"
)
failed_indices.append(i)
i += 1
chunk_sent = True
break
# Try smaller chunk size
continue
except Exception as e:
# Check if it's a Postmark API error
if attempt_size == 1:
# Single notification failed - determine the actual cause
error_message = str(e).lower()
error_type = type(e).__name__
# Check for HTTP 406 - Inactive recipient (common in Postmark errors)
if "406" in error_message or "inactive" in error_message:
logger.warning(
f"Failed to send notification at index {i}: "
f"Recipient marked as inactive by Postmark. "
f"Error: {e}. Disabling ALL notifications for this user."
)
# 1. Mark email as unverified
try:
await set_user_email_verification(
event.user_id, False
)
logger.info(
f"Set email verification to false for user {event.user_id}"
)
except Exception as deactivation_error:
logger.error(
f"Failed to deactivate email for user {event.user_id}: "
f"{deactivation_error}"
)
# 2. Disable all notification preferences
try:
await disable_all_user_notifications(event.user_id)
logger.info(
f"Disabled all notification preferences for user {event.user_id}"
)
except Exception as disable_error:
logger.error(
f"Failed to disable notification preferences: {disable_error}"
)
# 3. Clear ALL notification batches for this user
try:
await get_database_manager_async_client(
should_retry=False
).clear_all_user_notification_batches(event.user_id)
logger.info(
f"Cleared ALL notification batches for user {event.user_id}"
)
except Exception as remove_error:
logger.error(
f"Failed to clear batches for inactive recipient: {remove_error}"
)
# Stop processing - we've nuked everything for this user
return True
# Check for HTTP 422 - Malformed data
elif (
"422" in error_message
or "unprocessable" in error_message
):
logger.error(
f"Failed to send notification at index {i}: "
f"Malformed notification data rejected by Postmark. "
f"Error: {e}. Removing from batch permanently."
)
# Remove from batch - 422 means bad data that won't fix itself
if chunk_ids:
try:
await get_database_manager_async_client(
should_retry=False
).remove_notifications_from_batch(
event.user_id, event.type, chunk_ids
)
logger.info(
"Removed malformed notification from batch permanently"
)
except Exception as remove_error:
logger.error(
f"Failed to remove malformed notification: {remove_error}"
)
# Check if it's a ValueError for size limit
elif (
isinstance(e, ValueError)
and "too large" in error_message
):
logger.error(
f"Failed to send notification at index {i}: "
f"Notification size exceeds email limit. "
f"Error: {e}. Skipping this notification."
)
# Other API errors
else:
logger.error(
f"Failed to send notification at index {i}: "
f"Email API error ({error_type}): {e}. "
f"Skipping this notification."
)
# Even single notification is too large
logger.error(
f"Single notification too large to send: {e}. "
f"Skipping notification at index {i}"
)
failed_indices.append(i)
i += 1
chunk_sent = True
@@ -911,20 +742,18 @@ class NotificationManager(AppService):
failed_indices.append(i)
i += 1
# Check what remains in the batch (notifications are removed as sent)
remaining_batch = await get_database_manager_async_client(
should_retry=False
).get_user_notification_batch(event.user_id, event.type)
if not remaining_batch or not remaining_batch.notifications:
# Only empty the batch if ALL notifications were sent successfully
if successfully_sent_count == len(batch_messages):
logger.info(
f"All {successfully_sent_count} notifications sent and removed from batch"
f"Successfully sent all {successfully_sent_count} notifications, clearing batch"
)
await get_database_manager_async_client().empty_user_notification_batch(
event.user_id, event.type
)
else:
remaining_count = len(remaining_batch.notifications)
logger.warning(
f"Sent {successfully_sent_count} notifications. "
f"{remaining_count} remain in batch for retry due to errors."
f"Only sent {successfully_sent_count} of {len(batch_messages)} notifications. "
f"Failed indices: {failed_indices}. Batch will be retained for retry."
)
return True
except Exception as e:
@@ -942,9 +771,11 @@ class NotificationManager(AppService):
logger.info(f"Processing summary notification: {model}")
recipient_email = await get_database_manager_async_client(
should_retry=False
).get_user_email_by_id(event.user_id)
recipient_email = (
await get_database_manager_async_client().get_user_email_by_id(
event.user_id
)
)
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False
@@ -1017,14 +848,10 @@ class NotificationManager(AppService):
logger.exception(f"Fatal error in consumer for {queue_name}: {e}")
raise
def run_service(self):
# Queue the main _run_service task
asyncio.run_coroutine_threadsafe(self._run_service(), self.shared_event_loop)
# Start the main event loop
super().run_service()
@continuous_retry()
def run_service(self):
self.run_and_wait(self._run_service())
async def _run_service(self):
logger.info(f"[{self.service_name}] ⏳ Configuring RabbitMQ...")
self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.rabbitmq_config)
@@ -1090,10 +917,9 @@ class NotificationManager(AppService):
def cleanup(self):
"""Cleanup service resources"""
self.running = False
logger.info("⏳ Disconnecting RabbitMQ...")
self.run_and_wait(self.rabbitmq_service.disconnect())
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting RabbitMQ...")
self.run_and_wait(self.rabbitmq_service.disconnect())
class NotificationManagerClient(AppServiceClient):

View File

@@ -1,598 +0,0 @@
"""Tests for notification error handling in NotificationManager."""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from prisma.enums import NotificationType
from backend.data.notifications import AgentRunData, NotificationEventModel
from backend.notifications.notifications import NotificationManager
class TestNotificationErrorHandling:
"""Test cases for notification error handling in NotificationManager."""
@pytest.fixture
def notification_manager(self):
"""Create a NotificationManager instance for testing."""
with patch("backend.notifications.notifications.AppService.__init__"):
manager = NotificationManager()
manager.email_sender = MagicMock()
# Mock the _get_template method used by _process_batch
template_mock = Mock()
template_mock.base_template = "base"
template_mock.subject_template = "subject"
template_mock.body_template = "body"
manager.email_sender._get_template = Mock(return_value=template_mock)
# Mock the formatter
manager.email_sender.formatter = Mock()
manager.email_sender.formatter.format_email = Mock(
return_value=("subject", "body content")
)
manager.email_sender.formatter.env = Mock()
manager.email_sender.formatter.env.globals = {
"base_url": "http://example.com"
}
return manager
@pytest.fixture
def sample_batch_event(self):
"""Create a sample batch event for testing."""
return NotificationEventModel(
type=NotificationType.AGENT_RUN,
user_id="user_1",
created_at=datetime.now(timezone.utc),
data=AgentRunData(
agent_name="Test Agent",
credits_used=10.0,
execution_time=5.0,
node_count=3,
graph_id="graph_1",
outputs=[],
),
)
@pytest.fixture
def sample_batch_notifications(self):
"""Create sample batch notifications for testing."""
notifications = []
for i in range(3):
notification = Mock()
notification.type = NotificationType.AGENT_RUN
notification.data = {
"agent_name": f"Test Agent {i}",
"credits_used": 10.0 * (i + 1),
"execution_time": 5.0 * (i + 1),
"node_count": 3 + i,
"graph_id": f"graph_{i}",
"outputs": [],
}
notification.created_at = datetime.now(timezone.utc)
notifications.append(notification)
return notifications
@pytest.mark.asyncio
async def test_406_stops_all_processing_for_user(
self, notification_manager, sample_batch_event
):
"""Test that 406 inactive recipient error stops ALL processing for that user."""
with patch("backend.notifications.notifications.logger"), patch(
"backend.notifications.notifications.set_user_email_verification",
new_callable=AsyncMock,
) as mock_set_verification, patch(
"backend.notifications.notifications.disable_all_user_notifications",
new_callable=AsyncMock,
) as mock_disable_all, patch(
"backend.notifications.notifications.get_database_manager_async_client"
) as mock_db_client, patch(
"backend.notifications.notifications.generate_unsubscribe_link"
) as mock_unsub_link:
# Create batch of 5 notifications
notifications = []
for i in range(5):
notification = Mock()
notification.id = f"notif_{i}"
notification.type = NotificationType.AGENT_RUN
notification.data = {
"agent_name": f"Test Agent {i}",
"credits_used": 10.0 * (i + 1),
"execution_time": 5.0 * (i + 1),
"node_count": 3 + i,
"graph_id": f"graph_{i}",
"outputs": [],
}
notification.created_at = datetime.now(timezone.utc)
notifications.append(notification)
# Setup mocks
mock_db = mock_db_client.return_value
mock_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
mock_db.get_user_notification_batch = AsyncMock(
return_value=Mock(notifications=notifications)
)
mock_db.clear_all_user_notification_batches = AsyncMock()
mock_db.remove_notifications_from_batch = AsyncMock()
mock_unsub_link.return_value = "http://example.com/unsub"
# Mock internal methods
notification_manager._should_email_user_based_on_preference = AsyncMock(
return_value=True
)
notification_manager._should_batch = AsyncMock(return_value=True)
notification_manager._parse_message = Mock(return_value=sample_batch_event)
# Track calls
call_count = [0]
def send_side_effect(*args, **kwargs):
data = kwargs.get("data", [])
if isinstance(data, list) and len(data) == 1:
current_call = call_count[0]
call_count[0] += 1
# First two succeed, third hits 406
if current_call < 2:
return None
else:
raise Exception("Recipient marked as inactive (406)")
# Force single processing
raise Exception("Force single processing")
notification_manager.email_sender.send_templated.side_effect = (
send_side_effect
)
# Act
result = await notification_manager._process_batch(
sample_batch_event.model_dump_json()
)
# Assert
assert result is True
# Only 3 calls should have been made (2 successful, 1 failed with 406)
assert call_count[0] == 3
# User should be deactivated
mock_set_verification.assert_called_once_with("user_1", False)
mock_disable_all.assert_called_once_with("user_1")
mock_db.clear_all_user_notification_batches.assert_called_once_with(
"user_1"
)
# No further processing should occur after 406
@pytest.mark.asyncio
async def test_422_permanently_removes_malformed_notification(
self, notification_manager, sample_batch_event
):
"""Test that 422 error permanently removes the malformed notification from batch and continues with others."""
with patch("backend.notifications.notifications.logger") as mock_logger, patch(
"backend.notifications.notifications.get_database_manager_async_client"
) as mock_db_client, patch(
"backend.notifications.notifications.generate_unsubscribe_link"
) as mock_unsub_link:
# Create batch of 5 notifications
notifications = []
for i in range(5):
notification = Mock()
notification.id = f"notif_{i}"
notification.type = NotificationType.AGENT_RUN
notification.data = {
"agent_name": f"Test Agent {i}",
"credits_used": 10.0 * (i + 1),
"execution_time": 5.0 * (i + 1),
"node_count": 3 + i,
"graph_id": f"graph_{i}",
"outputs": [],
}
notification.created_at = datetime.now(timezone.utc)
notifications.append(notification)
# Setup mocks
mock_db = mock_db_client.return_value
mock_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
mock_db.get_user_notification_batch = AsyncMock(
side_effect=[
Mock(notifications=notifications),
Mock(notifications=[]), # Empty after processing
]
)
mock_db.remove_notifications_from_batch = AsyncMock()
mock_unsub_link.return_value = "http://example.com/unsub"
# Mock internal methods
notification_manager._should_email_user_based_on_preference = AsyncMock(
return_value=True
)
notification_manager._should_batch = AsyncMock(return_value=True)
notification_manager._parse_message = Mock(return_value=sample_batch_event)
# Track calls
call_count = [0]
successful_indices = []
removed_notification_ids = []
# Capture what gets removed
def remove_side_effect(user_id, notif_type, notif_ids):
removed_notification_ids.extend(notif_ids)
return None
mock_db.remove_notifications_from_batch.side_effect = remove_side_effect
def send_side_effect(*args, **kwargs):
data = kwargs.get("data", [])
if isinstance(data, list) and len(data) == 1:
current_call = call_count[0]
call_count[0] += 1
# Index 2 has malformed data (422)
if current_call == 2:
raise Exception(
"Unprocessable entity (422): Malformed email data"
)
else:
successful_indices.append(current_call)
return None
# Force single processing
raise Exception("Force single processing")
notification_manager.email_sender.send_templated.side_effect = (
send_side_effect
)
# Act
result = await notification_manager._process_batch(
sample_batch_event.model_dump_json()
)
# Assert
assert result is True
assert call_count[0] == 5 # All 5 attempted
assert len(successful_indices) == 4 # 4 succeeded (all except index 2)
assert 2 not in successful_indices # Index 2 failed
# Verify 422 error was logged
error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
assert any(
"422" in call or "malformed" in call.lower() for call in error_calls
)
# Verify all notifications were removed (4 successful + 1 malformed)
assert mock_db.remove_notifications_from_batch.call_count == 5
assert (
"notif_2" in removed_notification_ids
) # Malformed one was removed permanently
@pytest.mark.asyncio
async def test_oversized_notification_permanently_removed(
self, notification_manager, sample_batch_event
):
"""Test that oversized notifications are permanently removed from batch but others continue."""
with patch("backend.notifications.notifications.logger") as mock_logger, patch(
"backend.notifications.notifications.get_database_manager_async_client"
) as mock_db_client, patch(
"backend.notifications.notifications.generate_unsubscribe_link"
) as mock_unsub_link:
# Create batch of 5 notifications
notifications = []
for i in range(5):
notification = Mock()
notification.id = f"notif_{i}"
notification.type = NotificationType.AGENT_RUN
notification.data = {
"agent_name": f"Test Agent {i}",
"credits_used": 10.0 * (i + 1),
"execution_time": 5.0 * (i + 1),
"node_count": 3 + i,
"graph_id": f"graph_{i}",
"outputs": [],
}
notification.created_at = datetime.now(timezone.utc)
notifications.append(notification)
# Setup mocks
mock_db = mock_db_client.return_value
mock_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
mock_db.get_user_notification_batch = AsyncMock(
side_effect=[
Mock(notifications=notifications),
Mock(notifications=[]), # Empty after processing
]
)
mock_db.remove_notifications_from_batch = AsyncMock()
mock_unsub_link.return_value = "http://example.com/unsub"
# Mock internal methods
notification_manager._should_email_user_based_on_preference = AsyncMock(
return_value=True
)
notification_manager._should_batch = AsyncMock(return_value=True)
notification_manager._parse_message = Mock(return_value=sample_batch_event)
# Override formatter to simulate oversized on index 3
# original_format = notification_manager.email_sender.formatter.format_email
def format_side_effect(*args, **kwargs):
# Check if we're formatting index 3
data = kwargs.get("data", {}).get("notifications", [])
if data and len(data) == 1:
# Check notification content to identify index 3
if any(
"Test Agent 3" in str(n.data)
for n in data
if hasattr(n, "data")
):
# Return oversized message for index 3
return ("subject", "x" * 5_000_000) # Over 4.5MB limit
return ("subject", "normal sized content")
notification_manager.email_sender.formatter.format_email = Mock(
side_effect=format_side_effect
)
# Track calls
successful_indices = []
def send_side_effect(*args, **kwargs):
data = kwargs.get("data", [])
if isinstance(data, list) and len(data) == 1:
# Track which notification was sent based on content
for i, notif in enumerate(notifications):
if any(
f"Test Agent {i}" in str(n.data)
for n in data
if hasattr(n, "data")
):
successful_indices.append(i)
return None
return None
# Force single processing
raise Exception("Force single processing")
notification_manager.email_sender.send_templated.side_effect = (
send_side_effect
)
# Act
result = await notification_manager._process_batch(
sample_batch_event.model_dump_json()
)
# Assert
assert result is True
assert (
len(successful_indices) == 4
) # Only 4 sent (index 3 skipped due to size)
assert 3 not in successful_indices # Index 3 was not sent
# Verify oversized error was logged
error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
assert any(
"exceeds email size limit" in call or "oversized" in call.lower()
for call in error_calls
)
@pytest.mark.asyncio
async def test_generic_api_error_keeps_notification_for_retry(
self, notification_manager, sample_batch_event
):
"""Test that generic API errors keep notifications in batch for retry while others continue."""
with patch("backend.notifications.notifications.logger") as mock_logger, patch(
"backend.notifications.notifications.get_database_manager_async_client"
) as mock_db_client, patch(
"backend.notifications.notifications.generate_unsubscribe_link"
) as mock_unsub_link:
# Create batch of 5 notifications
notifications = []
for i in range(5):
notification = Mock()
notification.id = f"notif_{i}"
notification.type = NotificationType.AGENT_RUN
notification.data = {
"agent_name": f"Test Agent {i}",
"credits_used": 10.0 * (i + 1),
"execution_time": 5.0 * (i + 1),
"node_count": 3 + i,
"graph_id": f"graph_{i}",
"outputs": [],
}
notification.created_at = datetime.now(timezone.utc)
notifications.append(notification)
# Notification that failed with generic error
failed_notifications = [notifications[1]] # Only index 1 remains for retry
# Setup mocks
mock_db = mock_db_client.return_value
mock_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
mock_db.get_user_notification_batch = AsyncMock(
side_effect=[
Mock(notifications=notifications),
Mock(
notifications=failed_notifications
), # Failed ones remain for retry
]
)
mock_db.remove_notifications_from_batch = AsyncMock()
mock_unsub_link.return_value = "http://example.com/unsub"
# Mock internal methods
notification_manager._should_email_user_based_on_preference = AsyncMock(
return_value=True
)
notification_manager._should_batch = AsyncMock(return_value=True)
notification_manager._parse_message = Mock(return_value=sample_batch_event)
# Track calls
successful_indices = []
failed_indices = []
removed_notification_ids = []
# Capture what gets removed
def remove_side_effect(user_id, notif_type, notif_ids):
removed_notification_ids.extend(notif_ids)
return None
mock_db.remove_notifications_from_batch.side_effect = remove_side_effect
def send_side_effect(*args, **kwargs):
data = kwargs.get("data", [])
if isinstance(data, list) and len(data) == 1:
# Track which notification based on content
for i, notif in enumerate(notifications):
if any(
f"Test Agent {i}" in str(n.data)
for n in data
if hasattr(n, "data")
):
# Index 1 has generic API error
if i == 1:
failed_indices.append(i)
raise Exception("Network timeout - temporary failure")
else:
successful_indices.append(i)
return None
return None
# Force single processing
raise Exception("Force single processing")
notification_manager.email_sender.send_templated.side_effect = (
send_side_effect
)
# Act
result = await notification_manager._process_batch(
sample_batch_event.model_dump_json()
)
# Assert
assert result is True
assert len(successful_indices) == 4 # 4 succeeded (0, 2, 3, 4)
assert len(failed_indices) == 1 # 1 failed
assert 1 in failed_indices # Index 1 failed
# Verify generic error was logged
error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
assert any(
"api error" in call.lower() or "skipping" in call.lower()
for call in error_calls
)
# Only successful ones should be removed from batch (failed one stays for retry)
assert mock_db.remove_notifications_from_batch.call_count == 4
assert (
"notif_1" not in removed_notification_ids
) # Failed one NOT removed (stays for retry)
assert "notif_0" in removed_notification_ids # Successful one removed
assert "notif_2" in removed_notification_ids # Successful one removed
assert "notif_3" in removed_notification_ids # Successful one removed
assert "notif_4" in removed_notification_ids # Successful one removed
@pytest.mark.asyncio
async def test_batch_all_notifications_sent_successfully(
self, notification_manager, sample_batch_event
):
"""Test successful batch processing where all notifications are sent without errors."""
with patch("backend.notifications.notifications.logger") as mock_logger, patch(
"backend.notifications.notifications.get_database_manager_async_client"
) as mock_db_client, patch(
"backend.notifications.notifications.generate_unsubscribe_link"
) as mock_unsub_link:
# Create batch of 5 notifications
notifications = []
for i in range(5):
notification = Mock()
notification.id = f"notif_{i}"
notification.type = NotificationType.AGENT_RUN
notification.data = {
"agent_name": f"Test Agent {i}",
"credits_used": 10.0 * (i + 1),
"execution_time": 5.0 * (i + 1),
"node_count": 3 + i,
"graph_id": f"graph_{i}",
"outputs": [],
}
notification.created_at = datetime.now(timezone.utc)
notifications.append(notification)
# Setup mocks
mock_db = mock_db_client.return_value
mock_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
mock_db.get_user_notification_batch = AsyncMock(
side_effect=[
Mock(notifications=notifications),
Mock(notifications=[]), # Empty after all sent successfully
]
)
mock_db.remove_notifications_from_batch = AsyncMock()
mock_unsub_link.return_value = "http://example.com/unsub"
# Mock internal methods
notification_manager._should_email_user_based_on_preference = AsyncMock(
return_value=True
)
notification_manager._should_batch = AsyncMock(return_value=True)
notification_manager._parse_message = Mock(return_value=sample_batch_event)
# Track successful sends
successful_indices = []
removed_notification_ids = []
# Capture what gets removed
def remove_side_effect(user_id, notif_type, notif_ids):
removed_notification_ids.extend(notif_ids)
return None
mock_db.remove_notifications_from_batch.side_effect = remove_side_effect
def send_side_effect(*args, **kwargs):
data = kwargs.get("data", [])
if isinstance(data, list) and len(data) == 1:
# Track which notification was sent
for i, notif in enumerate(notifications):
if any(
f"Test Agent {i}" in str(n.data)
for n in data
if hasattr(n, "data")
):
successful_indices.append(i)
return None
return None # Success
# Force single processing
raise Exception("Force single processing")
notification_manager.email_sender.send_templated.side_effect = (
send_side_effect
)
# Act
result = await notification_manager._process_batch(
sample_batch_event.model_dump_json()
)
# Assert
assert result is True
# All 5 notifications should be sent successfully
assert len(successful_indices) == 5
assert successful_indices == [0, 1, 2, 3, 4]
# All notifications should be removed from batch
assert mock_db.remove_notifications_from_batch.call_count == 5
assert len(removed_notification_ids) == 5
for i in range(5):
assert f"notif_{i}" in removed_notification_ids
# No errors should be logged
assert mock_logger.error.call_count == 0
# Info message about successful sends should be logged
info_calls = [call[0][0] for call in mock_logger.info.call_args_list]
assert any("sent and removed" in call.lower() for call in info_calls)

View File

@@ -14,49 +14,19 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
@pytest.fixture
def test_user_id() -> str:
"""Test user ID fixture."""
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
return "test-user-id"
@pytest.fixture
def admin_user_id() -> str:
"""Admin user ID fixture."""
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
return "admin-user-id"
@pytest.fixture
def target_user_id() -> str:
"""Target user ID fixture."""
return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
@pytest.fixture
async def setup_test_user(test_user_id):
"""Create test user in database before tests."""
from backend.data.user import get_or_create_user
# Create the test user in the database using JWT token format
user_data = {
"sub": test_user_id,
"email": "test@example.com",
"user_metadata": {"name": "Test User"},
}
await get_or_create_user(user_data)
return test_user_id
@pytest.fixture
async def setup_admin_user(admin_user_id):
"""Create admin user in database before tests."""
from backend.data.user import get_or_create_user
# Create the admin user in the database using JWT token format
user_data = {
"sub": admin_user_id,
"email": "test-admin@example.com",
"user_metadata": {"name": "Test Admin"},
}
await get_or_create_user(user_data)
return admin_user_id
return "target-user-id"
@pytest.fixture

View File

@@ -32,7 +32,6 @@ from backend.data.model import (
OAuth2Credentials,
UserIntegrations,
)
from backend.data.onboarding import complete_webhook_trigger_step
from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
@@ -64,7 +63,7 @@ class LoginResponse(BaseModel):
state_token: str
@router.get("/{provider}/login", summary="Initiate OAuth flow")
@router.get("/{provider}/login")
async def login(
provider: Annotated[
ProviderName, Path(title="The provider to initiate an OAuth flow for")
@@ -102,7 +101,7 @@ class CredentialsMetaResponse(BaseModel):
)
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
@router.post("/{provider}/callback")
async def callback(
provider: Annotated[
ProviderName, Path(title="The target provider for this OAuth exchange")
@@ -180,7 +179,7 @@ async def callback(
)
@router.get("/credentials", summary="List Credentials")
@router.get("/credentials")
async def list_credentials(
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
@@ -221,9 +220,7 @@ async def list_credentials_by_provider(
]
@router.get(
"/{provider}/credentials/{cred_id}", summary="Get Specific Credential By ID"
)
@router.get("/{provider}/credentials/{cred_id}")
async def get_credential(
provider: Annotated[
ProviderName, Path(title="The provider to retrieve credentials for")
@@ -244,7 +241,7 @@ async def get_credential(
return credential
@router.post("/{provider}/credentials", status_code=201, summary="Create Credentials")
@router.post("/{provider}/credentials", status_code=201)
async def create_credentials(
user_id: Annotated[str, Security(get_user_id)],
provider: Annotated[
@@ -370,8 +367,6 @@ async def webhook_ingress_generic(
return
executions: list[Awaitable] = []
await complete_webhook_trigger_step(user_id)
for node in webhook.triggered_nodes:
logger.debug(f"Webhook-attached node: {node}")
if not node.is_triggered_by_event_type(event_type):

View File

@@ -1,10 +1,12 @@
import re
from typing import Set
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
class SecurityHeadersMiddleware:
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""
Middleware to add security headers to responses, with cache control
disabled by default for all endpoints except those explicitly allowed.
@@ -23,8 +25,6 @@ class SecurityHeadersMiddleware:
"/api/health",
"/api/v1/health",
"/api/status",
"/api/blocks",
"/api/v1/blocks",
# Public store/marketplace pages (read-only)
"/api/store/agents",
"/api/v1/store/agents",
@@ -49,7 +49,7 @@ class SecurityHeadersMiddleware:
}
def __init__(self, app: ASGIApp):
self.app = app
super().__init__(app)
# Compile regex patterns for wildcard matching
self.cacheable_patterns = [
re.compile(pattern.replace("*", "[^/]+"))
@@ -72,42 +72,26 @@ class SecurityHeadersMiddleware:
return False
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Pure ASGI middleware implementation for better performance than BaseHTTPMiddleware."""
if scope["type"] != "http":
await self.app(scope, receive, send)
return
async def dispatch(self, request: Request, call_next):
response: Response = await call_next(request)
# Extract path from scope
path = scope["path"]
# Add general security headers
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
response.headers["X-XSS-Protection"] = "1; mode=block"
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
async def send_wrapper(message: Message) -> None:
if message["type"] == "http.response.start":
# Add security headers to the response
headers = dict(message.get("headers", []))
# Add noindex header for shared execution pages
if "/public/shared" in request.url.path:
response.headers["X-Robots-Tag"] = "noindex, nofollow"
# Add general security headers (HTTP spec requires proper capitalization)
headers[b"X-Content-Type-Options"] = b"nosniff"
headers[b"X-Frame-Options"] = b"DENY"
headers[b"X-XSS-Protection"] = b"1; mode=block"
headers[b"Referrer-Policy"] = b"strict-origin-when-cross-origin"
# Default: Disable caching for all endpoints
# Only allow caching for explicitly permitted paths
if not self.is_cacheable_path(request.url.path):
response.headers["Cache-Control"] = (
"no-store, no-cache, must-revalidate, private"
)
response.headers["Pragma"] = "no-cache"
response.headers["Expires"] = "0"
# Add noindex header for shared execution pages
if "/public/shared" in path:
headers[b"X-Robots-Tag"] = b"noindex, nofollow"
# Default: Disable caching for all endpoints
# Only allow caching for explicitly permitted paths
if not self.is_cacheable_path(path):
headers[b"Cache-Control"] = (
b"no-store, no-cache, must-revalidate, private"
)
headers[b"Pragma"] = b"no-cache"
headers[b"Expires"] = b"0"
# Convert headers back to list format
message["headers"] = list(headers.items())
await send(message)
await self.app(scope, receive, send_wrapper)
return response

View File

@@ -1,6 +1,5 @@
import contextlib
import logging
import platform
from enum import Enum
from typing import Any, Optional
@@ -12,7 +11,6 @@ import uvicorn
from autogpt_libs.auth import add_auth_responses_to_openapi
from autogpt_libs.auth import verify_settings as verify_auth_settings
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.routing import APIRoute
from prisma.errors import PrismaError
@@ -72,26 +70,6 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.db.connect()
# Configure thread pool for FastAPI sync operation performance
# CRITICAL: FastAPI automatically runs ALL sync functions in this thread pool:
# - Any endpoint defined with 'def' (not async def)
# - Any dependency function defined with 'def' (not async def)
# - Manual run_in_threadpool() calls (like JWT decoding)
# Default pool size is only 40 threads, causing bottlenecks under high concurrency
config = backend.util.settings.Config()
try:
import anyio.to_thread
anyio.to_thread.current_default_thread_limiter().total_tokens = (
config.fastapi_thread_pool_size
)
logger.info(
f"Thread pool size set to {config.fastapi_thread_pool_size} for sync endpoint/dependency performance"
)
except (ImportError, AttributeError) as e:
logger.warning(f"Could not configure thread pool size: {e}")
# Continue without thread pool configuration
# Ensure SDK auto-registration is patched before initializing blocks
from backend.sdk.registry import AutoRegistry
@@ -162,9 +140,6 @@ app = fastapi.FastAPI(
app.add_middleware(SecurityHeadersMiddleware)
# Add GZip compression middleware for large responses (like /api/blocks)
app.add_middleware(GZipMiddleware, minimum_size=50_000) # 50KB threshold
# Add 401 responses to authenticated endpoints in OpenAPI spec
add_auth_responses_to_openapi(app)
@@ -298,28 +273,16 @@ class AgentServer(backend.util.service.AppProcess):
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers
)
config = backend.util.settings.Config()
uvicorn.run(
server_app,
host=backend.util.settings.Config().agent_api_host,
port=backend.util.settings.Config().agent_api_port,
log_config=None,
)
# Configure uvicorn with performance optimizations from Kludex FastAPI tips
uvicorn_config = {
"app": server_app,
"host": config.agent_api_host,
"port": config.agent_api_port,
"log_config": None,
# Use httptools for HTTP parsing (if available)
"http": "httptools",
# Only use uvloop on Unix-like systems (not supported on Windows)
"loop": "uvloop" if platform.system() != "Windows" else "auto",
}
# Only add debug in local environment (not supported in all uvicorn versions)
if config.app_env == backend.util.settings.AppEnvironment.LOCAL:
import os
# Enable asyncio debug mode via environment variable
os.environ["PYTHONASYNCIODEBUG"] = "1"
uvicorn.run(**uvicorn_config)
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Shutting down Agent Server...")
@staticmethod
async def test_execute_graph(

View File

@@ -11,6 +11,7 @@ import pydantic
import stripe
from autogpt_libs.auth import get_user_id, requires_user
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from autogpt_libs.utils.cache import cached
from fastapi import (
APIRouter,
Body,
@@ -23,8 +24,6 @@ from fastapi import (
Security,
UploadFile,
)
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict
@@ -39,7 +38,6 @@ from backend.data.credit import (
AutoTopUpConfig,
RefundRequest,
TransactionHistory,
UserCredit,
get_auto_top_up,
get_user_credit_model,
set_auto_top_up,
@@ -84,11 +82,9 @@ from backend.server.model import (
UpdateTimezoneRequest,
UploadFileResponse,
)
from backend.util.cache import cached
from backend.util.clients import get_scheduler_client
from backend.util.cloud_storage import get_cloud_storage_handler
from backend.util.exceptions import GraphValidationError, NotFoundError
from backend.util.json import dumps
from backend.util.settings import Settings
from backend.util.timezone_utils import (
convert_utc_time_to_user_timezone,
@@ -108,6 +104,9 @@ def _create_file_size_error(size_bytes: int, max_size_mb: int) -> HTTPException:
settings = Settings()
logger = logging.getLogger(__name__)
_user_credit_model = get_user_credit_model()
# Define the API routes
v1_router = APIRouter()
@@ -264,10 +263,13 @@ async def is_onboarding_enabled():
########################################################
def _compute_blocks_sync() -> str:
@cached()
def _get_cached_blocks() -> Sequence[dict[Any, Any]]:
"""
Synchronous function to compute blocks data.
This does the heavy lifting: instantiate 226+ blocks, compute costs, serialize.
Get cached blocks with thundering herd protection.
Uses sync_cache decorator to prevent multiple concurrent requests
from all executing the expensive block loading operation.
"""
from backend.data.credit import get_block_cost
@@ -277,27 +279,11 @@ def _compute_blocks_sync() -> str:
for block_class in block_classes.values():
block_instance = block_class()
if not block_instance.disabled:
# Get costs for this specific block class without creating another instance
costs = get_block_cost(block_instance)
# Convert BlockCost BaseModel objects to dictionaries for JSON serialization
costs_dict = [
cost.model_dump() if isinstance(cost, BaseModel) else cost
for cost in costs
]
result.append({**block_instance.to_dict(), "costs": costs_dict})
result.append({**block_instance.to_dict(), "costs": costs})
# Use our JSON utility which properly handles complex types through to_dict conversion
return dumps(result)
@cached(ttl_seconds=3600)
async def _get_cached_blocks() -> str:
"""
Async cached function with thundering herd protection.
On cache miss: runs heavy work in thread pool
On cache hit: returns cached string immediately (no thread pool needed)
"""
# Only run in thread pool on cache miss - cache hits return immediately
return await run_in_threadpool(_compute_blocks_sync)
return result
@v1_router.get(
@@ -305,28 +291,9 @@ async def _get_cached_blocks() -> str:
summary="List available blocks",
tags=["blocks"],
dependencies=[Security(requires_user)],
responses={
200: {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"items": {"additionalProperties": True, "type": "object"},
"type": "array",
"title": "Response Getv1List Available Blocks",
}
}
},
}
},
)
async def get_graph_blocks() -> Response:
# Cache hit: returns immediately, Cache miss: runs in thread pool
content = await _get_cached_blocks()
return Response(
content=content,
media_type="application/json",
)
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
return _get_cached_blocks()
@v1_router.post(
@@ -476,8 +443,7 @@ async def upload_file(
async def get_user_credits(
user_id: Annotated[str, Security(get_user_id)],
) -> dict[str, int]:
user_credit_model = await get_user_credit_model(user_id)
return {"credits": await user_credit_model.get_credits(user_id)}
return {"credits": await _user_credit_model.get_credits(user_id)}
@v1_router.post(
@@ -489,8 +455,9 @@ async def get_user_credits(
async def request_top_up(
request: RequestTopUp, user_id: Annotated[str, Security(get_user_id)]
):
user_credit_model = await get_user_credit_model(user_id)
checkout_url = await user_credit_model.top_up_intent(user_id, request.credit_amount)
checkout_url = await _user_credit_model.top_up_intent(
user_id, request.credit_amount
)
return {"checkout_url": checkout_url}
@@ -505,8 +472,7 @@ async def refund_top_up(
transaction_key: str,
metadata: dict[str, str],
) -> int:
user_credit_model = await get_user_credit_model(user_id)
return await user_credit_model.top_up_refund(user_id, transaction_key, metadata)
return await _user_credit_model.top_up_refund(user_id, transaction_key, metadata)
@v1_router.patch(
@@ -516,8 +482,7 @@ async def refund_top_up(
dependencies=[Security(requires_user)],
)
async def fulfill_checkout(user_id: Annotated[str, Security(get_user_id)]):
user_credit_model = await get_user_credit_model(user_id)
await user_credit_model.fulfill_checkout(user_id=user_id)
await _user_credit_model.fulfill_checkout(user_id=user_id)
return Response(status_code=200)
@@ -531,23 +496,18 @@ async def configure_user_auto_top_up(
request: AutoTopUpConfig, user_id: Annotated[str, Security(get_user_id)]
) -> str:
if request.threshold < 0:
raise HTTPException(status_code=422, detail="Threshold must be greater than 0")
raise ValueError("Threshold must be greater than 0")
if request.amount < 500 and request.amount != 0:
raise HTTPException(
status_code=422, detail="Amount must be greater than or equal to 500"
)
if request.amount != 0 and request.amount < request.threshold:
raise HTTPException(
status_code=422, detail="Amount must be greater than or equal to threshold"
)
raise ValueError("Amount must be greater than or equal to 500")
if request.amount < request.threshold:
raise ValueError("Amount must be greater than or equal to threshold")
user_credit_model = await get_user_credit_model(user_id)
current_balance = await user_credit_model.get_credits(user_id)
current_balance = await _user_credit_model.get_credits(user_id)
if current_balance < request.threshold:
await user_credit_model.top_up_credits(user_id, request.amount)
await _user_credit_model.top_up_credits(user_id, request.amount)
else:
await user_credit_model.top_up_credits(user_id, 0)
await _user_credit_model.top_up_credits(user_id, 0)
await set_auto_top_up(
user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
@@ -595,13 +555,15 @@ async def stripe_webhook(request: Request):
event["type"] == "checkout.session.completed"
or event["type"] == "checkout.session.async_payment_succeeded"
):
await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"])
await _user_credit_model.fulfill_checkout(
session_id=event["data"]["object"]["id"]
)
if event["type"] == "charge.dispute.created":
await UserCredit().handle_dispute(event["data"]["object"])
await _user_credit_model.handle_dispute(event["data"]["object"])
if event["type"] == "refund.created" or event["type"] == "charge.dispute.closed":
await UserCredit().deduct_credits(event["data"]["object"])
await _user_credit_model.deduct_credits(event["data"]["object"])
return Response(status_code=200)
@@ -615,8 +577,7 @@ async def stripe_webhook(request: Request):
async def manage_payment_method(
user_id: Annotated[str, Security(get_user_id)],
) -> dict[str, str]:
user_credit_model = await get_user_credit_model(user_id)
return {"url": await user_credit_model.create_billing_portal_session(user_id)}
return {"url": await _user_credit_model.create_billing_portal_session(user_id)}
@v1_router.get(
@@ -634,8 +595,7 @@ async def get_credit_history(
if transaction_count_limit < 1 or transaction_count_limit > 1000:
raise ValueError("Transaction count limit must be between 1 and 1000")
user_credit_model = await get_user_credit_model(user_id)
return await user_credit_model.get_transaction_history(
return await _user_credit_model.get_transaction_history(
user_id=user_id,
transaction_time_ceiling=transaction_time,
transaction_count_limit=transaction_count_limit,
@@ -652,8 +612,7 @@ async def get_credit_history(
async def get_refund_requests(
user_id: Annotated[str, Security(get_user_id)],
) -> list[RefundRequest]:
user_credit_model = await get_user_credit_model(user_id)
return await user_credit_model.get_refund_requests(user_id)
return await _user_credit_model.get_refund_requests(user_id)
########################################################
@@ -875,8 +834,7 @@ async def execute_graph(
graph_version: Optional[int] = None,
preset_id: Optional[str] = None,
) -> execution_db.GraphExecutionMeta:
user_credit_model = await get_user_credit_model(user_id)
current_balance = await user_credit_model.get_credits(user_id)
current_balance = await _user_credit_model.get_credits(user_id)
if current_balance <= 0:
raise HTTPException(
status_code=402,

View File

@@ -23,13 +23,10 @@ client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user, setup_test_user):
def setup_app_auth(mock_jwt_user):
"""Setup auth overrides for all tests in this module"""
from autogpt_libs.auth.jwt_utils import get_jwt_payload
# setup_test_user fixture already executed and user is created in database
# It returns the user_id which we don't need to await
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
@@ -197,12 +194,8 @@ def test_get_user_credits(
snapshot: Snapshot,
) -> None:
"""Test get user credits endpoint"""
mock_credit_model = Mock()
mock_credit_model = mocker.patch("backend.server.routers.v1._user_credit_model")
mock_credit_model.get_credits = AsyncMock(return_value=1000)
mocker.patch(
"backend.server.routers.v1.get_user_credit_model",
return_value=mock_credit_model,
)
response = client.get("/credits")
@@ -222,14 +215,10 @@ def test_request_top_up(
snapshot: Snapshot,
) -> None:
"""Test request top up endpoint"""
mock_credit_model = Mock()
mock_credit_model = mocker.patch("backend.server.routers.v1._user_credit_model")
mock_credit_model.top_up_intent = AsyncMock(
return_value="https://checkout.example.com/session123"
)
mocker.patch(
"backend.server.routers.v1.get_user_credit_model",
return_value=mock_credit_model,
)
request_data = {"credit_amount": 500}
@@ -272,74 +261,6 @@ def test_get_auto_top_up(
)
def test_configure_auto_top_up(
mocker: pytest_mock.MockFixture,
snapshot: Snapshot,
) -> None:
"""Test configure auto top-up endpoint - this test would have caught the enum casting bug"""
# Mock the set_auto_top_up function to avoid database operations
mocker.patch(
"backend.server.routers.v1.set_auto_top_up",
return_value=None,
)
# Mock credit model to avoid Stripe API calls
mock_credit_model = mocker.AsyncMock()
mock_credit_model.get_credits.return_value = 50 # Current balance below threshold
mock_credit_model.top_up_credits.return_value = None
mocker.patch(
"backend.server.routers.v1.get_user_credit_model",
return_value=mock_credit_model,
)
# Test data
request_data = {
"threshold": 100,
"amount": 500,
}
response = client.post("/credits/auto-top-up", json=request_data)
# This should succeed with our fix, but would have failed before with the enum casting error
assert response.status_code == 200
assert response.json() == "Auto top-up settings updated"
def test_configure_auto_top_up_validation_errors(
mocker: pytest_mock.MockFixture,
) -> None:
"""Test configure auto top-up endpoint validation"""
# Mock set_auto_top_up to avoid database operations for successful case
mocker.patch("backend.server.routers.v1.set_auto_top_up")
# Mock credit model to avoid Stripe API calls for the successful case
mock_credit_model = mocker.AsyncMock()
mock_credit_model.get_credits.return_value = 50
mock_credit_model.top_up_credits.return_value = None
mocker.patch(
"backend.server.routers.v1.get_user_credit_model",
return_value=mock_credit_model,
)
# Test negative threshold
response = client.post(
"/credits/auto-top-up", json={"threshold": -1, "amount": 500}
)
assert response.status_code == 422 # Validation error
# Test amount too small (but not 0)
response = client.post(
"/credits/auto-top-up", json={"threshold": 100, "amount": 100}
)
assert response.status_code == 422 # Validation error
# Test amount = 0 (should be allowed)
response = client.post("/credits/auto-top-up", json={"threshold": 100, "amount": 0})
assert response.status_code == 200 # Should succeed
# Graphs endpoints tests
def test_get_graphs(
mocker: pytest_mock.MockFixture,

View File

@@ -11,6 +11,8 @@ from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
_user_credit_model = get_user_credit_model()
router = APIRouter(
prefix="/admin",
@@ -31,8 +33,7 @@ async def add_user_credits(
logger.info(
f"Admin user {admin_user_id} is adding {amount} credits to user {user_id}"
)
user_credit_model = await get_user_credit_model(user_id)
new_balance, transaction_key = await user_credit_model._add_transaction(
new_balance, transaction_key = await _user_credit_model._add_transaction(
user_id,
amount,
transaction_type=CreditTransactionType.GRANT,

View File

@@ -1,5 +1,5 @@
import json
from unittest.mock import AsyncMock, Mock
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
@@ -7,12 +7,12 @@ import prisma.enums
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from prisma import Json
from pytest_snapshot.plugin import Snapshot
import backend.server.v2.admin.credit_admin_routes as credit_admin_routes
import backend.server.v2.admin.model as admin_model
from backend.data.model import UserTransaction
from backend.util.json import SafeJson
from backend.util.models import Pagination
app = fastapi.FastAPI()
@@ -37,14 +37,12 @@ def test_add_user_credits_success(
) -> None:
"""Test successful credit addition by admin"""
# Mock the credit model
mock_credit_model = Mock()
mock_credit_model = mocker.patch(
"backend.server.v2.admin.credit_admin_routes._user_credit_model"
)
mock_credit_model._add_transaction = AsyncMock(
return_value=(1500, "transaction-123-uuid")
)
mocker.patch(
"backend.server.v2.admin.credit_admin_routes.get_user_credit_model",
return_value=mock_credit_model,
)
request_data = {
"user_id": target_user_id,
@@ -64,17 +62,11 @@ def test_add_user_credits_success(
call_args = mock_credit_model._add_transaction.call_args
assert call_args[0] == (target_user_id, 500)
assert call_args[1]["transaction_type"] == prisma.enums.CreditTransactionType.GRANT
# Check that metadata is a SafeJson object with the expected content
assert isinstance(call_args[1]["metadata"], SafeJson)
actual_metadata = call_args[1]["metadata"]
expected_data = {
"admin_id": admin_user_id,
"reason": "Test credit grant for debugging",
}
# SafeJson inherits from Json which stores parsed data in the .data attribute
assert actual_metadata.data["admin_id"] == expected_data["admin_id"]
assert actual_metadata.data["reason"] == expected_data["reason"]
# Check that metadata is a Json object with the expected content
assert isinstance(call_args[1]["metadata"], Json)
assert call_args[1]["metadata"] == Json(
{"admin_id": admin_user_id, "reason": "Test credit grant for debugging"}
)
# Snapshot test the response
configured_snapshot.assert_match(
@@ -89,14 +81,12 @@ def test_add_user_credits_negative_amount(
) -> None:
"""Test credit deduction by admin (negative amount)"""
# Mock the credit model
mock_credit_model = Mock()
mock_credit_model = mocker.patch(
"backend.server.v2.admin.credit_admin_routes._user_credit_model"
)
mock_credit_model._add_transaction = AsyncMock(
return_value=(200, "transaction-456-uuid")
)
mocker.patch(
"backend.server.v2.admin.credit_admin_routes.get_user_credit_model",
return_value=mock_credit_model,
)
request_data = {
"user_id": "target-user-id",

View File

@@ -7,7 +7,6 @@ import fastapi
import fastapi.responses
import prisma.enums
import backend.server.v2.store.cache as store_cache
import backend.server.v2.store.db
import backend.server.v2.store.model
import backend.util.json
@@ -87,11 +86,6 @@ async def review_submission(
StoreSubmission with updated review information
"""
try:
already_approved = (
await backend.server.v2.store.db.check_submission_already_approved(
store_listing_version_id=store_listing_version_id,
)
)
submission = await backend.server.v2.store.db.review_store_submission(
store_listing_version_id=store_listing_version_id,
is_approved=request.is_approved,
@@ -99,11 +93,6 @@ async def review_submission(
internal_comments=request.internal_comments or "",
reviewer_id=user_id,
)
state_changed = already_approved != request.is_approved
# Clear caches when the request is approved as it updates what is shown on the store
if state_changed:
store_cache.clear_all_caches()
return submission
except Exception as e:
logger.exception("Error reviewing submission: %s", e)

Some files were not shown because too many files have changed in this diff Show More