Compare commits

80 Commits

Author SHA1 Message Date
Nicholas Tindle
d89b84ba2b remove logs 2025-10-18 03:09:26 -05:00
Nicholas Tindle
4619b07945 try this 2025-10-18 02:53:34 -05:00
Nicholas Tindle
d43535e491 Update route.ts 2025-10-18 02:39:28 -05:00
Nicholas Tindle
a35914889a feat: turn off captcha 2025-10-18 02:36:26 -05:00
Nicholas Tindle
7c248f2d6e test: disable turnstile 2025-10-18 02:22:53 -05:00
Nicholas Tindle
d4a7ce3846 feat: aggressive logging 2025-10-18 02:16:46 -05:00
Nicholas Tindle
605a198c09 feat: add log? 2025-10-18 01:44:21 -05:00
Nicholas Tindle
a3389485a7 Merge branch 'hotfix/waitlist-error-display' of https://github.com/Significant-Gravitas/AutoGPT into hotfix/waitlist-error-display 2025-10-17 23:46:48 -05:00
Nicholas Tindle
cd439e912a fix: same thing 2025-10-17 23:37:56 -05:00
Nicholas Tindle
7b32290582 Merge branch 'dev' into hotfix/waitlist-error-display 2025-10-17 23:35:02 -05:00
Nicholas Tindle
e3137382c3 feat: add error code check 2025-10-17 23:33:31 -05:00
Nicholas Tindle
097a19141d fix(frontend): improve waitlist error display for users not on allowlist (#11196)
## Summary

This PR improves the user experience for users who are not on the
waitlist during sign-up. When a user attempts to sign up or log in with
an email that's not on the allowlist, they now see a clear, helpful
modal with a direct call-to-action to join the waitlist.

Fixes
[OPEN-2794](https://linear.app/autogpt/issue/OPEN-2794/display-waitlist-error-for-users-not-on-waitlist-during-sign-up)

## Changes

- Updated `EmailNotAllowedModal` with improved messaging and a "Join
Waitlist" button
- 🔧 Fixed OAuth provider signup/login to properly display the waitlist
modal
- 📝 Enhanced auth-code-error page to detect and display
waitlist-specific errors
- 💬 Added helpful guidance about checking email address and Discord
support link
- 🎯 Consistent waitlist error handling across all auth flows (regular
signup, OAuth, error pages)

## Test Plan

Tested locally by:
1. Attempting signup with non-allowlisted email - modal appears 
2. Attempting Google SSO with non-allowlisted account - modal appears 
3. Modal shows "Join Waitlist" button that opens
https://agpt.co/waitlist 
4. Help text about checking email and Discord support is visible 

## Screenshots

The new waitlist modal includes:
- Clear "Join the Waitlist" title
- Explanation that platform is in closed beta
- "Join Waitlist" button (opens in new tab)
- Help text about checking email address
- Discord support link for users who need help

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2025-10-18 03:37:31 +00:00
Nicholas Tindle
65f2c04ef1 fix: lint 2025-10-17 14:02:23 -05:00
Nicholas Tindle
865abdb9e0 fix(frontend): correct waitlist error detection in auth-code-error page
- Removed incorrect 403 error code check (Supabase doesn't send HTTP codes in OAuth redirects)
- Added isWaitlistErrorFromParams() utility for OAuth callback errors
- Now properly detects P0001 and waitlist errors from error_description parameter
- Consistent error detection across all auth flows

The auth-code-error page receives errors via URL hash parameters from
Supabase OAuth redirects, not HTTP status codes. This fix ensures we
check the error description content rather than expecting a 403 code
that would never be sent.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-17 13:46:20 -05:00
Nicholas Tindle
b59b200bd6 formatting 2025-10-17 13:40:49 -05:00
Nicholas Tindle
e7fb4cce5a fix: resolve merge conflicts 2025-10-17 13:23:58 -05:00
Nicholas Tindle
85e2aef6ad refactor(frontend): improve waitlist error detection with centralized utilities
- Created utility functions for robust waitlist error detection
- Added multiple fallback checks: P0001 error code, message text, and table reference
- Centralized logic in utils.ts to avoid duplication
- Added privacy-conscious logging that sanitizes email addresses
- More resilient detection that handles various Supabase error formats

The error detection now checks for:
1. PostgreSQL P0001 error code (primary indicator)
2. "not allowed to register" message from trigger
3. Reference to "allowed_users" table

This makes the waitlist check more reliable even if Supabase changes
how it formats PostgreSQL trigger errors.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-17 13:20:26 -05:00
Nicholas Tindle
85a8fb598e Apply suggestion from @Pwuts
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2025-10-17 13:06:35 -05:00
Nicholas Tindle
ae20da8aaa fix(frontend): improve waitlist error display for users not on allowlist
- Updated EmailNotAllowedModal with clear waitlist CTA and helpful messaging
- Added "Join Waitlist" button that opens https://agpt.co/waitlist
- Fixed OAuth provider signup/login to properly display waitlist modal
- Enhanced auth-code-error page to detect and display waitlist errors
- Added helpful guidance about checking email and Discord support link
- Consistent waitlist error handling across all auth flows

Fixes OPEN-2794

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-17 12:37:03 -05:00
Zamil Majdy
0bb2b87c32 fix(backend): resolve UserBalance migration issues and credit spending bug (#11192)
## Summary
Fix critical UserBalance migration and spending issues affecting users
with credits from transaction history but no UserBalance records.

## Root Issues Fixed

### Issue 1: UserBalance Migration Complexity
- **Problem**: Complex data migration with timestamp logic issues and
potential race conditions
- **Solution**: Simplified to idempotent table creation only,
application handles auto-population

### Issue 2: Credit Spending Bug  
- **Problem**: Users with $10.0 from transaction history couldn't spend
$0.16
- **Root Cause**: `_add_transaction` and `_enable_transaction` only
checked UserBalance table, returning 0 balance for users without records
- **Solution**: Enhanced both methods with transaction history fallback
logic

### Issue 3: Exception Handling Inconsistency
- **Problem**: Raw SQL unique violations raised different exception
types than Prisma ORM
- **Solution**: Convert raw SQL unique violations to
`UniqueViolationError` at source

## Changes Made

### Migration Cleanup
- **Idempotent operations**: Use `CREATE TABLE IF NOT EXISTS`, `CREATE
INDEX IF NOT EXISTS`
- **Inline foreign key**: Define constraint within `CREATE TABLE`
instead of separate `ALTER TABLE`
- **Removed data migration**: Application creates UserBalance records
on-demand
- **Safe to re-run**: No errors if table/index/constraint already exists
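
A minimal SQLite stand-in for the idempotent DDL style described above (table, index, and column names are illustrative; the real migration targets PostgreSQL):

```python
import sqlite3

# Illustrative DDL in the idempotent style described above; the real
# migration uses different columns, constraints, and an inline FK.
DDL = """
CREATE TABLE IF NOT EXISTS user_balance (
    user_id TEXT PRIMARY KEY,
    balance INTEGER NOT NULL DEFAULT 0
);
CREATE INDEX IF NOT EXISTS idx_user_balance_balance
    ON user_balance (balance);
"""

conn = sqlite3.connect(":memory:")
conn.executescript(DDL)
conn.executescript(DDL)  # re-running is a no-op thanks to IF NOT EXISTS
```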

### Credit Logic Fixes
- **Enhanced `_add_transaction`**: Added transaction history fallback in
`user_balance_lock` CTE
- **Enhanced `_enable_transaction`**: Added same fallback logic for
payment fulfillment
- **Exception normalization**: Convert raw SQL unique violations to
`UniqueViolationError`
- **Simplified `onboarding_reward`**: Use standardized
`UniqueViolationError` catching

### SQL Fallback Pattern
```sql
COALESCE(
    (SELECT balance FROM UserBalance WHERE userId = ? FOR UPDATE),
    -- Fallback: compute from transaction history if UserBalance doesn't exist
    (SELECT COALESCE(ct.runningBalance, 0) 
     FROM CreditTransaction ct 
     WHERE ct.userId = ? AND ct.isActive = true AND ct.runningBalance IS NOT NULL 
     ORDER BY ct.createdAt DESC LIMIT 1),
    0
) as balance
```

## Impact

### Before
- Users with transaction history but no UserBalance couldn't spend
credits
- Migration had complex timestamp logic with potential bugs
- Raw SQL and Prisma exceptions handled differently
- Error: "Insufficient balance of $10.0, where this will cost $0.16"

### After  
- Seamless spending for all users regardless of UserBalance record
existence
- Simple, idempotent migration that's safe to re-run
- Consistent exception handling across all credit operations
- Automatic UserBalance record creation during first transaction
- Backward compatible - existing users unaffected

## Business Value
- **Eliminates user frustration**: Users can spend their credits
immediately
- **Smooth migration path**: From old User.balance to new UserBalance
table
- **Better reliability**: Atomic operations with proper error handling
- **Maintainable code**: Consistent patterns across credit operations

## Test Plan
- [ ] Manual testing with users who have transaction history but no
UserBalance records
- [ ] Verify migration can be run multiple times safely
- [ ] Test spending credits works for all user scenarios
- [ ] Verify payment fulfillment (`_enable_transaction`) works correctly
- [ ] Add comprehensive test coverage for this scenario

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-10-17 19:46:13 +07:00
Zamil Majdy
73c0b6899a fix(backend): Remove advisory locks for atomic credit operations (#11143)
## Problem
High QPS failures on `spend_credits` operations due to lock contention
from `pg_advisory_xact_lock` causing serialization and seconds of wait
time.

## Solution 
Replace PostgreSQL advisory locks with atomic database operations using
CTEs (Common Table Expressions).

### Key Changes
- **Add persistent balance column** to User table for O(1) balance
lookups
- **Atomic CTE-based operations** for all credit transactions using
UPDATE...RETURNING pattern
- **Comprehensive concurrency tests** with 7 test scenarios including
stress testing
- **Remove all advisory lock usage** from the credit system

### Implementation Details
1. **Migration**: Adds balance column with backfill from transaction
history
2. **Atomic Operations**: All credit operations now use single atomic
CTEs that update balance and create transaction in one query
3. **Race Condition Prevention**: WHERE clauses in UPDATE statements
ensure balance never goes negative
4. **BetaUserCredit Compatibility**: Preserved monthly refill logic with
updated `_add_transaction` signature
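
The race-condition prevention in step 3 can be illustrated with a conditional UPDATE whose WHERE clause enforces the non-negative invariant. This is a SQLite stand-in under assumed table and column names; the real implementation uses PostgreSQL CTEs with UPDATE...RETURNING:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE user_balance (user_id TEXT PRIMARY KEY, balance INTEGER)")
conn.execute("INSERT INTO user_balance VALUES ('u1', 1000)")

def spend(conn, user_id: str, cost: int) -> bool:
    # Atomic conditional update: the WHERE clause guarantees the balance
    # never goes negative, so no advisory lock is required.
    cur = conn.execute(
        "UPDATE user_balance SET balance = balance - ? "
        "WHERE user_id = ? AND balance >= ?",
        (cost, user_id, cost),
    )
    conn.commit()
    return cur.rowcount == 1  # False => insufficient balance, nothing changed
```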

### Performance Impact
- Eliminated lock contention bottlenecks
- O(1) balance lookups instead of O(n) transaction aggregation
- Atomic operations prevent race conditions without locks
- Supports high QPS without serialization delays

### Testing
- All existing tests pass
- New concurrency test suite (`credit_concurrency_test.py`) with:
  - Concurrent spends from same user
  - Insufficient balance handling
  - Mixed operations (spends, top-ups, balance checks)
  - Race condition prevention
  - Integer overflow protection
  - Stress testing with 100 concurrent operations

### Breaking Changes
None - all existing APIs maintain compatibility

🤖 Generated with [Claude Code](https://claude.ai/code)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Enhanced top‑up flows with top‑up types, clearer credit→dollar
formatting, and idempotent onboarding rewards.

* **Bug Fixes**
* Fixed race conditions for concurrent spends/top‑ups, added
integer‑overflow and underflow protection, stronger input validation,
and improved refund/dispute handling.

* **Refactor**
* Persisted per‑user balance with atomic updates for reliable balances;
admin history now prefetches balances.

* **Tests**
* Added extensive concurrency, refund, ceiling/underflow and migration
test suites.

* **Chores**
* Database migration to add persisted user balance; APIKey status
extended (SUSPENDED).
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Swifty <craigswift13@gmail.com>
2025-10-17 17:05:05 +07:00
Zamil Majdy
4c853a54d7 Merge commit 'e4bc728d40332e7c2b1edec5f1b200f1917950e2' into HEAD 2025-10-17 16:43:23 +07:00
Zamil Majdy
dfdd632161 fix(backend/util): handle nested Pydantic models in SafeJson (#11188)
## Summary

Fixes a critical serialization bug introduced in PR #11187 where
`SafeJson` failed to serialize dictionaries containing Pydantic models,
causing 500 Internal Server Errors in the executor service.

## Problem

The error manifested as:
```
CRITICAL: Operation Approaching Failure Threshold: Service communication: '_call_method_async'
Current attempt: 50/50
Error: HTTPServerError: HTTP 500: Server error '500 Internal Server Error' 
for url 'http://autogpt-database-manager.prod-agpt.svc.cluster.local:8005/create_graph_execution'
```

Root cause in `create_graph_execution`
(backend/data/execution.py:656-657):
```python
"credentialInputs": SafeJson(credential_inputs) if credential_inputs else Json({})
```

Where `credential_inputs: Mapping[str, CredentialsMetaInput]` is a dict
containing Pydantic models.

After PR #11187's refactor, `_sanitize_value()` only converted top-level
BaseModel instances to dicts, but didn't handle BaseModel instances
nested inside dicts/lists/tuples. This caused Prisma's JSON serializer
to fail with:
```
TypeError: Type <class 'backend.data.model.CredentialsMetaInput'> not serializable
```

## Solution

Added BaseModel handling to `_sanitize_value()` to recursively convert
Pydantic models to dicts before sanitizing:

```python
elif isinstance(value, BaseModel):
    # Convert Pydantic models to dict and recursively sanitize
    return _sanitize_value(value.model_dump(exclude_none=True))
```

This ensures all nested Pydantic models are properly serialized
regardless of nesting depth.

## Changes

- **backend/util/json.py**: Added BaseModel check to `_sanitize_value()`
function
- **backend/util/test_json.py**: Added 6 comprehensive tests covering:
  - Dict containing Pydantic models
  - Deeply nested Pydantic models  
  - Lists of Pydantic models in dicts
  - The exact CredentialsMetaInput scenario
  - Complex mixed structures
  - Models with control characters

## Testing

- All new tests pass
- Verified fix resolves the production 500 error
- Code formatted with `poetry run format`

## Related

- Fixes issues introduced in PR #11187
- Related to executor service 500 errors in production

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Bentlybro <Github@bentlybro.com>
Co-authored-by: Claude <noreply@anthropic.com>
2025-10-17 09:27:09 +00:00
Zamil Majdy
e4bc728d40 Revert "Revert "fix(backend/util): rewrite SafeJson to prevent Invalid \escape errors (#11187)""
This reverts commit 8258338caf.
2025-10-17 15:25:30 +07:00
Swifty
2c6d85d15e feat(platform): Shared cache (#11150)
### Problem
When running multiple backend pods in production, requests can be routed
to different pods causing inconsistent cache states. Additionally, the
current cache implementation in `autogpt_libs` doesn't support shared
caching across processes, leading to data inconsistency and redundant
cache misses.

### Changes 🏗️

- **Moved cache implementation from autogpt_libs to backend**
(`/backend/backend/util/cache.py`)
  - Removed `/autogpt_libs/autogpt_libs/utils/cache.py`
  - Centralized cache utilities within the backend module
  - Updated all import statements across the codebase

- **Implemented Redis-based shared caching**
- Added `shared_cache` parameter to `@cached` decorator for
cross-process caching
  - Implemented Redis connection pooling for efficient cache operations
  - Added support for cache key pattern matching and bulk deletion
  - Added TTL refresh on cache access with `refresh_ttl_on_get` option

- **Enhanced cache functionality**
  - Added thundering herd protection with double-checked locking
  - Implemented thread-local caching with `@thread_cached` decorator
- Added cache management methods: `cache_clear()`, `cache_info()`,
`cache_delete()`
  - Added support for both sync and async functions

- **Updated store caching** (`/backend/server/v2/store/cache.py`)
  - Enabled shared caching for all store-related cache functions
  - Set appropriate TTL values (5-15 minutes) for different cache types
  - Added `clear_all_caches()` function for cache invalidation

- **Added Redis configuration**
  - Added Redis connection settings to backend settings
  - Configured dedicated connection pool for cache operations
  - Set up binary mode for pickle serialization
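
The thundering-herd protection via double-checked locking can be sketched in-process. This is a simplified stand-in: the real `@cached` decorator also handles Redis-backed shared caching, TTLs, and async functions, all omitted here:

```python
import threading
from functools import wraps

def cached(func):
    """In-process sketch of the `@cached` idea, keyed on positional args only."""
    store = {}
    lock = threading.Lock()

    @wraps(func)
    def wrapper(*args):
        if args in store:          # fast path: no lock taken on a cache hit
            return store[args]
        with lock:
            if args in store:      # re-check: another thread may have filled it
                return store[args]
            store[args] = func(*args)  # only one caller computes the value
            return store[args]

    return wrapper
```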

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verify Redis connection and cache operations work correctly
  - [x] Test shared cache across multiple backend instances
  - [x] Verify cache invalidation with `clear_all_caches()`
- [x] Run cache tests: `poetry run pytest
backend/backend/util/cache_test.py`
  - [x] Test thundering herd protection under concurrent load
  - [x] Verify TTL refresh functionality with `refresh_ttl_on_get=True`
  - [x] Test thread-local caching for request-scoped data
  - [x] Ensure no performance regression vs in-memory cache

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes (Redis already configured)
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)
- Redis cache configuration uses existing Redis service settings
(REDIS_HOST, REDIS_PORT, REDIS_PASSWORD)
  - No new environment variables required
2025-10-17 07:56:01 +00:00
Bentlybro
8258338caf Revert "fix(backend/util): rewrite SafeJson to prevent Invalid \escape errors (#11187)"
This reverts commit e62a56e8ba.
2025-10-17 08:31:23 +01:00
Zamil Majdy
374f35874c feat(platform): Add LaunchDarkly flag for platform payment system (#11181)
## Summary

Implement selective rollout of payment functionality using LaunchDarkly
feature flags to enable gradual deployment to pilot users.

- Add `ENABLE_PLATFORM_PAYMENT` flag to control credit system behavior
- Update `get_user_credit_model` to use user-specific flag evaluation  
- Replace hardcoded `NEXT_PUBLIC_SHOW_BILLING_PAGE` with LaunchDarkly
flag
- Enable payment UI components only for flagged users
- Maintain backward compatibility with existing beta credit system
- Default to beta monthly credits when flag is disabled
- Fix tests to work with new async credit model function

## Key Changes

### Backend
- **Credit Model Selection**: The `get_user_credit_model()` function now
takes a `user_id` parameter and uses LaunchDarkly to determine which
credit model to return:
- Flag enabled → `UserCredit` (payment system enabled, no monthly
refills)
- Flag disabled → `BetaUserCredit` (current behavior with monthly
refills)
  
- **Flag Integration**: Added `ENABLE_PLATFORM_PAYMENT` flag and
integrated LaunchDarkly evaluation throughout the credit system

- **API Updates**: All credit-related endpoints now use the
user-specific credit model instead of a global instance

### Frontend
- **Dynamic UI**: Payment-related components (billing page, wallet
refill) now show/hide based on the LaunchDarkly flag
- **Removed Environment Variable**: Replaced
`NEXT_PUBLIC_SHOW_BILLING_PAGE` with runtime flag evaluation

### Testing
- **Test Fixes**: Updated all tests that referenced the removed global
`_user_credit_model` to use proper mocking of the new async function

## Deployment Strategy

This implementation enables a controlled rollout:
1. Deploy with flag disabled (default) - no behavior change for existing
users
2. Enable flag for pilot/beta users via LaunchDarkly dashboard
3. Monitor usage and feedback from pilot users
4. Gradually expand to more users
5. Eventually enable for all users once validated

## Test Plan

- [x] Unit tests pass for credit system components
- [x] Payment UI components show/hide correctly based on flag
- [x] Default behavior (flag disabled) maintains current functionality
- [x] Flag enabled users get payment system without monthly refills
- [x] Admin credit operations work correctly
- [x] Backward compatibility maintained

🤖 Generated with [Claude Code](https://claude.ai/code)

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-10-17 06:11:39 +00:00
Zamil Majdy
e62a56e8ba fix(backend/util): rewrite SafeJson to prevent Invalid \escape errors (#11187)
## Summary

Fixes the `Invalid \escape` error occurring in
`/upsert_execution_output` endpoint by completely rewriting the SafeJson
implementation.

## Problem

- Error: `POST /upsert_execution_output failed: Invalid \escape: line 1
column 36404 (char 36403)`
- Caused by data containing literal backslash-u sequences (e.g.,
`\u0000` as text, not actual null characters)
- Previous implementation tried to remove problematic escape sequences
from JSON strings
- This created invalid JSON when it removed `\\u0000` and left invalid
sequences like `\w`

## Solution

Completely rewrote SafeJson to work on Python data structures instead of
JSON strings:

1. **Direct data sanitization**: Recursively walks through dicts, lists,
and tuples to remove control characters directly from strings
2. **No JSON string manipulation**: Avoids all escape sequence parsing
issues
3. **More efficient**: Eliminates the serialize → sanitize → deserialize
cycle
4. **Preserves valid content**: Backslashes, paths, and literal text are
correctly preserved

## Changes

- Removed `POSTGRES_JSON_ESCAPES` regex (no longer needed)
- Added `_sanitize_value()` helper function for recursive sanitization
- Simplified `SafeJson()` to convert Pydantic models and sanitize data
structures
- Added `import json  # noqa: F401` for backwards compatibility

## Testing

- Verified fix resolves the `Invalid \escape` error
- All existing SafeJson unit tests pass
- Problematic data with literal escape sequences no longer causes
errors
- Code formatted with `poetry run format`

## Technical Details

**Before (JSON string approach):**
```python
# Serialize to JSON string
json_string = dumps(data)
# Remove escape sequences from string (BREAKS!)
sanitized = regex.sub("", json_string)
# Parse back (FAILS with Invalid \escape)
return Json(json.loads(sanitized))
```

**After (data structure approach):**
```python
# Convert Pydantic to dict
data = model.model_dump() if isinstance(data, BaseModel) else data
# Recursively sanitize strings in data structure
sanitized = _sanitize_value(data)
# Return as Json (no parsing needed)
return Json(sanitized)
```

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-10-17 05:56:08 +00:00
Abhimanyu Yadav
f3f9a60157 feat(frontend): add extra info in custom node in new builder (#11172)
Currently, we don’t add category and cost information to custom nodes in
the new builder. This means we’re rendering with the correct information
and costs are displayed accurately based on the selected discriminator
value.

<img width="441" height="781" alt="Screenshot 2025-10-15 at 2 43 33 PM"
src="https://github.com/user-attachments/assets/8199cfa7-4353-4de2-8c15-b68aa86e458c"
/>


### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] All information is displayed correctly.
- [x] I’ve tried changing the discriminator value, and we get the
correct cost for the selected value.
2025-10-17 04:35:22 +00:00
Bently
9469b9e2eb feat(platform/backend): Add Claude Haiku 4.5 model support (#11179)
### Changes 🏗️

- **Added Claude Haiku 4.5 model support** (`claude-haiku-4-5-20251001`)
- Added model to `LlmModel` enum in
`autogpt_platform/backend/backend/blocks/llm.py`
- Configured model metadata with 200k context window and 64k max output
tokens
- Set pricing to 4 credits per million tokens in
`backend/data/block_cost_config.py`
  
- **Classic Forge Integration**
- Added `CLAUDE4_5_HAIKU_v1` to Anthropic provider in
`classic/forge/forge/llm/providers/anthropic.py`
- Configured with $1/1M prompt tokens and $5/1M completion tokens
pricing
  - Enabled function call API support

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  
  **Test Plan:**
- [x] Verify Claude Haiku 4.5 model appears in the LLM block model
selection dropdown
- [x] Test basic text generation using Claude Haiku 4.5 in an agent
workflow

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

<details>
  <summary>Configuration changes</summary>

  - No environment variable changes required
  - No docker-compose changes needed
- Model configuration is handled through existing Anthropic API
integration
</details>




https://github.com/user-attachments/assets/bbc42c47-0e7c-4772-852e-55aa91f4d253

---------

Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Bently <Bentlybro@users.noreply.github.com>
2025-10-16 10:11:38 +00:00
Zamil Majdy
b7ae2c2fd2 fix(backend): move DatabaseError to backend.util.exceptions for better layer separation (#11177)
## Summary

Move DatabaseError from store-specific exceptions to generic backend
exceptions for proper layer separation, while also fixing store
exception inheritance to ensure proper HTTP status codes.

## Problem

1. **Poor Layer Separation**: DatabaseError was defined in
store-specific exceptions but represents infrastructure concerns that
affect the entire backend
2. **Incorrect HTTP Status Codes**: Store exceptions inherited from
Exception instead of ValueError, causing 500 responses for client errors
3. **Reusability Issues**: Other backend modules couldn't use
DatabaseError for DB operations
4. **Blanket Catch Issues**: Store-specific catches were affecting
generic database operations

## Solution

### Move DatabaseError to Generic Location
- Move from backend.server.v2.store.exceptions to
backend.util.exceptions
- Update all 23 references in backend/server/v2/store/db.py to use new
location
- Remove from StoreError inheritance hierarchy

### Fix Complete Store Exception Hierarchy
- MediaUploadError: Changed from Exception to ValueError inheritance
(client errors → 400)
- StoreError: Changed from Exception to ValueError inheritance (business
logic errors → 400)
- Store NotFound exceptions: Changed to inherit from NotFoundError (→
404)
- DatabaseError: Now properly inherits from Exception (infrastructure
errors → 500)
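
The resulting hierarchy and its status-code mapping can be sketched as follows; `status_for` is an illustrative stand-in for the global exception handlers, not the actual FastAPI code:

```python
class NotFoundError(ValueError):
    """Missing-resource errors; the global handler maps these to 404."""

class StoreError(ValueError):
    """Store business-logic errors -> 400 via the ValueError handler."""

class MediaUploadError(ValueError):
    """Client-side upload validation errors -> 400."""

class AgentNotFoundError(NotFoundError):
    """Requested agent does not exist -> 404."""

class DatabaseError(Exception):
    """Infrastructure failure -> 500."""

def status_for(exc: Exception) -> int:
    # Illustrative stand-in for the global exception handlers:
    # most-specific class wins, so NotFoundError is checked first.
    if isinstance(exc, NotFoundError):
        return 404
    if isinstance(exc, ValueError):
        return 400
    return 500
```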

## Benefits

### Proper Layer Separation
- Database errors are infrastructure concerns, not store-specific
business logic
- Store exceptions focus on business validation and client errors  
- Clean separation between infrastructure and business logic layers

### Correct HTTP Status Codes
- DatabaseError: 500 (server infrastructure errors)
- Store NotFound errors: 404 (via existing NotFoundError handler)
- Store validation errors: 400 (via existing ValueError handler)
- Media upload errors: 400 (client validation errors)

### Architectural Improvements
- DatabaseError now reusable across entire backend
- Eliminates blanket catch issues affecting generic DB operations
- All store exceptions use global exception handlers properly
- Future store exceptions automatically get proper status codes

## Files Changed

- **backend/util/exceptions.py**: Add DatabaseError class
- **backend/server/v2/store/exceptions.py**: Remove DatabaseError, fix
inheritance hierarchy
- **backend/server/v2/store/db.py**: Update all DatabaseError references
to new location

## Result

- **No more stack trace spam**: Expected business logic errors handled
properly
- **Proper HTTP semantics**: 500 for infrastructure, 400/404 for
client errors
- **Better architecture**: Clean layer separation and reusable
components
- **Fixes original issue**: AgentNotFoundError now returns 404 instead
of 500

This addresses the logging issue mentioned by @zamilmajdy while also
implementing the architectural improvements suggested by @Pwuts.
2025-10-16 09:51:58 +00:00
Abhimanyu Yadav
8b995c2394 feat(frontend): add saving ability in new builder (#11148)
This PR introduces saving functionality to the new builder interface,
allowing users to save and update agent flows. The implementation
includes both UI components and backend integration for persistent
storage of agent configurations.



https://github.com/user-attachments/assets/95ee46de-2373-4484-9f34-5f09aa071c5e


### Key Features Added:

#### 1. **Save Control Component** (`NewSaveControl`)
- Added a new save control popover in the control panel with form inputs
for agent name, description, and version display
- Integrated with the new control panel as a primary action button with
a floppy disk icon
- Supports keyboard shortcuts (Ctrl+S / Cmd+S) for quick saving

#### 2. **Graph Persistence Logic**
- Implemented `useNewSaveControl` hook to handle:
  - Creating new graphs via `usePostV1CreateNewGraph`
  - Updating existing graphs via `usePutV1UpdateGraphVersion`
- Intelligent comparison to prevent unnecessary saves when no changes
are made
  - URL parameter management for flowID and flowVersion tracking

#### 3. **Loading State Management**
- Added `GraphLoadingBox` component to display a loading indicator while
graphs are being fetched
- Enhanced `useFlow` hook with loading state tracking
(`isFlowContentLoading`)
- Improved UX with clear visual feedback during graph operations

#### 4. **Component Reorganization**
- Refactored components from `NewBlockMenu` to `NewControlPanel`
directory structure for better organization:
- Moved all block menu related components under
`NewControlPanel/NewBlockMenu/`
- Separated save control into its own module
(`NewControlPanel/NewSaveControl/`)
  - Improved modularity and separation of concerns

#### 5. **State Management Enhancements**
- Added `controlPanelStore` for managing control panel states (e.g.,
save popover visibility)
- Enhanced `nodeStore` with `getBackendNodes()` method for retrieving
nodes in backend format
- Added `getBackendLinks()` to `edgeStore` for consistent link
formatting

### Technical Improvements:

- **Graph Comparison Logic**: Implemented `graphsEquivalent()` helper to
deeply compare saved and current graph states, preventing redundant
saves
- **Form Validation**: Used Zod schema validation for save form inputs
with proper constraints
- **Error Handling**: Comprehensive error handling with user-friendly
toast notifications
- **Query Invalidation**: Proper cache invalidation after successful
saves to ensure data consistency
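
The comparison guard can be sketched as follows. This is a Python illustration of the TypeScript `graphsEquivalent()` helper; normalizing via sorted-key serialization is one way to implement the deep compare, and the field names in the test are illustrative:

```python
import json

def graphs_equivalent(saved: dict, current: dict) -> bool:
    """Sketch of the redundant-save guard: serialize both graph
    snapshots with sorted keys and compare the results."""
    return json.dumps(saved, sort_keys=True) == json.dumps(current, sort_keys=True)
```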

### UI/UX Enhancements:

- Clean, modern save dialog with clear labeling and placeholder text
- Real-time version display showing the current graph version
- Disabled state for save button during operations to prevent double
submissions
- Toast notifications for success and error states
- Higher z-index for GraphLoadingBox to ensure visibility over other
elements

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Saving is working perfectly. All nodes, links, their positions,
and hardcoded data are saved correctly.
  - [x] If there are no changes, the user cannot save the graph.
2025-10-16 08:06:18 +00:00
Zamil Majdy
12b1067017 fix(backend/store): improve store exception hierarchy for proper HTTP status codes (#11176)
## Summary

Fix store exception hierarchy to prevent ERROR level stack trace spam
for expected business logic errors and ensure proper HTTP status codes.

## Problem

The original error from production logs showed AgentNotFoundError for
non-existent agents like autogpt/domain-drop-catcher was:
- Returning 500 status codes instead of 404 
- Generating ERROR level stack traces in logs for expected not found
scenarios
- Bypassing global exception handlers due to improper inheritance

## Root Cause

Store exceptions inherited from Exception instead of ValueError, causing
them to bypass the global ValueError handler (400) and fall through to
the generic Exception handler (500) with full stack traces.

## Solution

Create proper exception hierarchy for ALL store-related errors by
making:
- MediaUploadError inherit from ValueError instead of Exception
- StoreError inherit from ValueError instead of Exception  
- Store NotFound exceptions inherit from NotFoundError (which inherits
from ValueError)

## Changes Made

1. MediaUploadError: Changed from Exception to ValueError inheritance
2. StoreError: Changed from Exception to ValueError inheritance  
3. Store NotFound exceptions: Changed to inherit from NotFoundError

## Benefits

- Correct HTTP status codes: Not found errors return 404, validation
errors return 400
- No more 500 stack trace spam for expected business logic errors
- Clean consistent error handling using existing global handlers
- Future-proof: Any new store exceptions automatically get proper status
codes

## Result

- AgentNotFoundError for autogpt/domain-drop-catcher now returns 404
instead of 500
- InvalidFileTypeError, VirusDetectedError, etc. now return 400 instead
of 500
- No ERROR level stack traces for expected client errors
- Proper HTTP semantics throughout the store API
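
A minimal sketch of the resulting hierarchy (class names are from this PR; the handler dispatch is simplified for illustration):

```python
class NotFoundError(ValueError):
    """Base for 'not found' errors: mapped to HTTP 404."""

class StoreError(ValueError):
    """Base for store business-logic errors: mapped to HTTP 400."""

class MediaUploadError(ValueError):
    """Media upload validation failures: mapped to HTTP 400."""

class AgentNotFoundError(NotFoundError):
    """Raised when a store agent slug does not exist."""

def status_code_for(exc: Exception) -> int:
    # Global handlers dispatch most-specific first; anything that is
    # still a bare Exception falls through to 500 with a stack trace.
    if isinstance(exc, NotFoundError):
        return 404
    if isinstance(exc, ValueError):
        return 400
    return 500
```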
2025-10-16 04:36:49 +00:00
Zamil Majdy
ba53cb78dc fix(backend/util): comprehensive SafeJson sanitization to prevent PostgreSQL null character errors (#11174)
## Summary
Fix critical SafeJson function to properly sanitize JSON-encoded Unicode
escape sequences that were causing PostgreSQL 22P05 errors when null
characters in web scraped content were stored as "\u0000" in the
database.

## Root Cause Analysis
The existing SafeJson function in backend/util/json.py:
1. Only removed raw control characters (\x00-\x08, etc.) using
POSTGRES_CONTROL_CHARS regex
2. Failed to handle JSON-encoded Unicode escape sequences (\u0000,
\u0001, etc.)
3. When web scraping returned content with null bytes, these were
JSON-encoded as "\u0000" strings
4. PostgreSQL rejected these Unicode escape sequences, causing 22P05
errors

## Changes Made

### Enhanced SafeJson Function (backend/util/json.py:147-153)
- **Add POSTGRES_JSON_ESCAPES regex**: Comprehensive pattern targeting
all PostgreSQL-incompatible Unicode and single-char JSON escape
sequences
- **Unicode escapes**: \u0000-\u0008, \u000B-\u000C, \u000E-\u001F,
\u007F (preserves \u0009=tab, \u000A=newline, \u000D=carriage return)
- **Single-char escapes**: \b (backspace), \f (form feed) with negative
lookbehind/lookahead to preserve file paths like "C:\\file.txt"
- **Two-pass sanitization**: Remove JSON escape sequences first, then
raw characters as fallback
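
The two-pass idea, in simplified form (these regexes are illustrative; the real patterns in `backend/util/json.py` also handle the `\b`/`\f` single-char escapes with lookarounds):

```python
import re

# Pass 1: JSON-encoded escapes that PostgreSQL rejects, i.e. the six
# characters \u0000 appearing literally inside an encoded string.
# Preserves \u0009 (tab), \u000A (newline), \u000D (carriage return).
POSTGRES_JSON_ESCAPES = re.compile(
    r"\\u(?:000[0-8]|000[bB]|000[cC]|00(?:0[eE]|1[0-9a-fA-F])|007[fF])"
)

# Pass 2: raw control characters, as a fallback.
POSTGRES_CONTROL_CHARS = re.compile(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]")

def sanitize_encoded(json_text: str) -> str:
    json_text = POSTGRES_JSON_ESCAPES.sub("", json_text)
    return POSTGRES_CONTROL_CHARS.sub("", json_text)
```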

### Comprehensive Test Coverage (backend/util/test_json.py:219-414)
Added 11 new test methods covering:
- **Control character sanitization**: Verify dangerous characters (\x00,
\x07, \x0C, etc.) are removed while preserving safe whitespace (\t, \n,
\r)
- **Web scraping content**: Simulate SearchTheWebBlock scenarios with
null bytes and control characters
- **Code preservation**: Ensure legitimate file paths, JSON strings,
regex patterns, and programming code are preserved
- **Unicode escape handling**: Test both \u0000-style and \b/\f-style
escape sequences
- **Edge case protection**: Prevent over-matching of legitimate
sequences like "mybfile.txt" or "\\u0040"
- **Mixed content scenarios**: Verify only problematic sequences are
removed while preserving legitimate content

## Validation Results
-  All 24 SafeJson tests pass including 11 new comprehensive
sanitization tests
-  Control characters properly removed: \x00, \x01, \x08, \x0C, \x1F,
\x7F
-  Safe characters preserved: \t (tab), \n (newline), \r (carriage
return)
-  File paths preserved: "C:\\Users\\file.txt", "\\\\server\\share"
-  Programming code preserved: regex patterns, JSON strings, SQL
escapes
-  Unicode escapes sanitized: \u0000 → removed, \u0048 ("H") →
preserved if valid
-  No false positives: Legitimate sequences not accidentally removed
-  poetry run format succeeds without errors

## Impact
- **Prevents PostgreSQL 22P05 errors**: No more null character database
rejections from web scraping
- **Maintains data integrity**: Legitimate content preserved while
dangerous characters removed
- **Comprehensive protection**: Handles both raw bytes and JSON-encoded
escape sequences
- **Web scraping reliability**: SearchTheWebBlock and similar blocks now
store content safely
- **Backward compatibility**: Existing SafeJson behavior unchanged for
legitimate content

## Test Plan
- [x] All existing SafeJson tests pass (24/24)
- [x] New comprehensive sanitization tests pass (11/11)
- [x] Control character removal verified
- [x] Legitimate content preservation verified
- [x] Web scraping scenarios tested
- [x] Code formatting and type checking passes

🤖 Generated with [Claude Code](https://claude.ai/code)

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-10-15 21:25:30 +00:00
Reinier van der Leer
f9778cc87e fix(blocks): Unhide "Add to Dictionary" block's dictionary input (#11175)
The `dictionary` input on the Add to Dictionary block is hidden, even
though it is the main input.

### Changes 🏗️

- Mark `dictionary` explicitly as not advanced (so not hidden by
default)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - Trivial change, no test needed
2025-10-15 15:04:56 +00:00
Nicholas Tindle
b230b1b5cf feat(backend): Add Sentry user and tag tracking to node execution (#11170)
Integrates Sentry SDK to set user and contextual tags during node
execution for improved error tracking and user count analytics. Ensures
Sentry context is properly set and restored, and exceptions are captured
with relevant context before scope restoration.

<!-- Clearly explain the need for these changes: -->

### Changes 🏗️
Adds sentry tracking to block failures
<!-- Concisely describe all of the changes made in this pull request:
-->

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Test to make sure the userid and block details show up in Sentry
  - [x] make sure other errors aren't contaminated 


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- New Features
- Added conditional support for feature flags when configured, enabling
targeted rollouts and experiments without impacting unconfigured
environments.

- Chores
- Enhanced error monitoring with richer contextual data during node
execution to improve stability and diagnostics.
- Updated metrics initialization to dynamically include feature flag
integrations when available, without altering behavior for unconfigured
setups.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-10-15 14:33:08 +00:00
Reinier van der Leer
1925e77733 feat(backend): Include default input values in graph export (#11173)
Since #10323, one-time graph inputs are no longer stored on the input
nodes (#9139), so we can reasonably assume that the default value set by
the graph creator will be safe to export.

### Changes 🏗️

- Don't strip the default value from input nodes in
`NodeModel.stripped_for_export(..)`, except for inputs marked as
`secret`
- Update and expand tests for graph export secrets stripping mechanism

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Expanded tests pass
- Relatively simple change with good test coverage, no manual test
needed
2025-10-15 14:04:44 +00:00
Copilot
9bc9b53b99 fix(backend): Add channel ID support to SendDiscordMessageBlock for consistency with other Discord blocks (#11055)
## Problem

The `SendDiscordMessageBlock` only accepted channel names, while other
Discord blocks like `SendDiscordFileBlock` and `SendDiscordEmbedBlock`
accept both channel IDs and channel names. This inconsistency made it
difficult to use channel IDs with the message sending block, which is
often more reliable and direct than name-based lookup.

## Solution

Updated `SendDiscordMessageBlock` to accept both channel IDs and channel
names through the `channel_name` field, matching the implementation
pattern used in other Discord blocks.

### Changes Made

1. **Enhanced channel resolution logic** to try parsing the input as a
channel ID first, then fall back to name-based search:
   ```python
   # Try to parse as channel ID first
   try:
       channel_id = int(channel_name)
       channel = client.get_channel(channel_id)
   except ValueError:
       # Not an ID, treat as channel name
       # ... search guilds for matching channel name
   ```

2. **Updated field descriptions** to clarify the dual functionality:
- `channel_name`: Now describes that it accepts "Channel ID or channel
name"
   - `server_name`: Clarified as "only needed if using channel name"

3. **Added type checking** to ensure the resolved channel can send
messages before attempting to send

4. **Updated documentation** to reflect the new capability
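
For illustration, the full resolution order might look like this (a sketch against the discord.py API; the guild-search fallback shown here is an assumption of how that branch is written):

```python
def resolve_channel(client, channel_name: str):
    # Try the input as a numeric channel ID first.
    try:
        channel_id = int(channel_name)
    except ValueError:
        channel_id = None
    if channel_id is not None:
        return client.get_channel(channel_id)
    # Not an ID: fall back to searching guilds by channel name.
    for guild in client.guilds:
        for channel in guild.text_channels:
            if channel.name == channel_name:
                return channel
    return None
```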

## Backward Compatibility

 **Fully backward compatible**: The field name remains `channel_name`
(not renamed), and all existing workflows using channel names will
continue to work exactly as before.

 **New capability**: Users can now also provide channel IDs (e.g.,
`"123456789012345678"`) for more direct channel targeting.

## Testing

- All existing tests pass, including `SendDiscordMessageBlock` and all
other Discord block tests
- Implementation verified to match the pattern used in
`SendDiscordFileBlock` and `SendDiscordEmbedBlock`
- Code passes all linting, formatting, and type checking

Fixes https://github.com/Significant-Gravitas/AutoGPT/issues/10909


<details>

<summary>Original prompt</summary>

> Issue Title: SendDiscordMessage needs to take a channel id as an
option under channelname the same as the other discord blocks
> Issue Description: with how we can process the other discord blocks we
should do the same here with the identifiers being allowed to be a
channel name or id. we can't rename the field though or that will break
backwards compatibility
> Fixes
https://linear.app/autogpt/issue/OPEN-2701/senddiscordmessage-needs-to-take-a-channel-id-as-an-option-under
> 
> 
> Comment by User :
> This thread is for an agent session with githubcopilotcodingagent.
> 
> Comment by User 055a3053-5ab6-449a-bcfa-990768594185:
> the ones with boxes around them need confirmed for lables but yeah its
related but not dupe
> 
> Comment by User 264d7bf4-db2a-46fa-a880-7d67b58679e6:
> this might be a duplicate since there is a related ticket but not sure
> 
> Comment by User :
> This comment thread is synced to a corresponding [GitHub
issue](https://github.com/Significant-Gravitas/AutoGPT/issues/10909).
All replies are displayed in both locations.
> 
> 


</details>



<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* New Features
* Send Discord Message block now accepts a channel ID in addition to
channel name.
  * Server name is only required when using a channel name.
* Improved channel detection and validation with clearer errors if the
channel isn’t found.

* Documentation
* Updated block documentation to reflect support for channel ID or name
and clarify when server name is needed.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ntindle <8845353+ntindle@users.noreply.github.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <ntindle@users.noreply.github.com>
Co-authored-by: Bently <Github@bentlybro.com>
2025-10-15 13:04:53 +00:00
Toran Bruce Richards
adfa75eca8 feat(blocks): Add references output pin to Fact Checker block (#11166)
Closes #11163

## Summary
Expanded the Fact Checker block to yield its references list from the
Jina AI API response.

## Changes 🏗️
- Added `Reference` TypedDict to define the structure of reference
objects
- Added `references` field to the Output schema
- Modified the `run` method to extract and yield references from the API
response
- Added fallback to empty list if references are not present
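
A rough sketch of the new output shape (field names follow the reference description above: a URL, a key quote, and a support indicator, but treat the exact keys as assumptions):

```python
from typing import List, TypedDict

class Reference(TypedDict):
    url: str
    keyQuote: str
    isSupportive: bool

def extract_references(api_response: dict) -> List[Reference]:
    # Fall back to an empty list when the API omits references.
    data = api_response.get("data") or {}
    return data.get("references") or []
```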

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Verified that the Fact Checker block schema includes the new
references field
- [x] Confirmed that references are properly extracted from the API
response when present
- [x] Tested the fallback behavior when references are not in the API
response
- [x] Ensured backward compatibility - existing functionality remains
unchanged
- [x] Verified the Reference TypedDict matches the expected API response
structure

Generated with [Claude Code](https://claude.ai/code)

## Summary by CodeRabbit

* **New Features**
* Fact-checking results now include a references list to support
verification.
* Each reference provides a URL, a key quote, and an indicator showing
whether it supports the claim.
* References are presented alongside factuality, result, and reasoning
when available; otherwise, an empty list is returned.
* Enhances transparency and traceability of fact-check outcomes without
altering existing result fields.

---------

Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Toran Bruce Richards <Torantulino@users.noreply.github.com>
Co-authored-by: Bentlybro <Github@bentlybro.com>
2025-10-15 10:19:43 +00:00
seer-by-sentry[bot]
0f19d01483 fix(frontend): Improve error handling for invalid agent files (#11165)
### Changes 🏗️

<img width="672" height="761" alt="Screenshot 2025-10-14 at 16 12 50"
src="https://github.com/user-attachments/assets/9e664ade-10fe-4c09-af10-b26d10dca360"
/>


Fixes
[BUILDER-3YG](https://sentry.io/organizations/significant-gravitas/issues/6942679655/).
The issue was that: User uploaded an incompatible external agent persona
file lacking required flow graph keys (`nodes`, `links`).

- Improves error handling when an invalid agent file is uploaded.
- Provides a more user-friendly error message indicating the file must
be a valid agent.json file exported from the AutoGPT platform.
- Clears the invalid file from the form and resets the agent object to
null.

This fix was generated by Seer in Sentry, triggered by Toran Bruce
Richards. 👁️ Run ID: 1943626

Not quite right? [Click here to continue debugging with
Seer.](https://sentry.io/organizations/significant-gravitas/issues/6942679655/?seerDrawer=true)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
- [x] Test that uploading an invalid agent file (e.g., missing `nodes`
or `links`) triggers the improved error handling and displays the
user-friendly error message.
- [x] Verify that the invalid file is cleared from the form after the
error, and the agent object is reset to null.

---------

Co-authored-by: seer-by-sentry[bot] <157164994+seer-by-sentry[bot]@users.noreply.github.com>
Co-authored-by: Swifty <craigswift13@gmail.com>
Co-authored-by: Lluis Agusti <hi@llu.lu>
2025-10-15 09:55:00 +00:00
Abhimanyu Yadav
112c39f6a6 fix(frontend): fix auto select credential mechanism in new builder (#11171)
We’re currently facing two problems with credentials:

1. When the discriminator input value changes, the credential field in the form data should be cleared completely.
2. When a different discriminator is selected and that provider already has credentials, the latest one should be auto-selected.

This PR addresses both issues.

### Changes 🏗️
- Updated CredentialField to utilize a new setCredential function for
managing selected credentials.
- Implemented logic to auto-select the latest credential when none is
selected and clear the credential if the provider changes.
- Improved SelectCredential component with a defaultValue prop and
adjusted styling for better UI consistency.
- Removed unnecessary console logs from helper functions to clean up the
code.

<!-- Concisely describe all of the changes made in this pull request:
-->

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Credential selection works perfectly with both the discriminator
and normal addition.
  - [x] Auto-select credential is also working.
2025-10-15 08:39:05 +00:00
Toran Bruce Richards
22946f4617 feat(blocks): add dedicated Perplexity block (#11164)
Fixes #11162

## Summary

Implements a new Perplexity block that allows users to query
Perplexity's sonar models via OpenRouter with support for extracting URL
citations and annotations.

## Changes

- Add new block for Perplexity sonar models (sonar, sonar-pro,
sonar-deep-research)
- Support model selection for all three Perplexity models
- Implement annotations output pin for URL citations and source
references
- Integrate with OpenRouter API for accessing Perplexity models
- Follow existing block patterns from AI text generator block

## Test Plan

- Block successfully instantiates
- Block is properly loaded by the dynamic loading system
- Output fields include response and annotations as required

Generated with [Claude Code](https://claude.ai/code)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- New Features
- Added a Perplexity integration block to query Sonar models via
OpenRouter.
- Supports multiple model options, optional system prompt, and
adjustable max tokens.
- Returns concise responses with citation-style annotations extracted
from the model output.
  - Provides clear error messages when requests fail.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Toran Bruce Richards <Torantulino@users.noreply.github.com>
Co-authored-by: Bentlybro <Github@bentlybro.com>
2025-10-15 08:34:37 +00:00
Ubbe
938834ac83 dx(frontend): enable Next.js sourcemaps for Sentry (#11161)
## Changes 🏗️

Next.js Sourcemaps aren't working on production, followed:

- https://docs.sentry.io/platforms/javascript/guides/nextjs/sourcemaps/
- https://docs.sentry.io/organization/integrations/deployment/vercel/

## Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] We will see once deployed ...

### For configuration changes:

None
2025-10-15 12:47:14 +04:00
Zamil Majdy
934cb3a9c7 feat(backend): Make execution limit per user per graph and reduce to 25 (#11169)
## Summary
- Changed max_concurrent_graph_executions_per_user from 50 to 25
concurrent executions
- Updated the limit to be per user per graph instead of globally per
user
- Users can now run different graphs concurrently without being limited
by executions of other graphs
- Enhanced database query to filter by both user_id and graph_id

## Changes Made
- **Settings**: Reduced default limit from 50 to 25 and updated
description to clarify per-graph scope
- **Database Layer**: Modified `get_graph_executions_count` to accept
optional `graph_id` parameter
- **Executor Manager**: Updated rate limiting logic to check
per-user-per-graph instead of per-user globally
- **Logging**: Enhanced warning messages to include graph_id context

## Test plan
- [ ] Verify that users can run up to 25 concurrent executions of the
same graph
- [ ] Verify that users can run different graphs concurrently without
interference
- [ ] Test rate limiting behavior when limit is exceeded for a specific
graph
- [ ] Confirm logging shows correct graph_id context in rate limit
messages

## Impact
This change improves the user experience by allowing concurrent
execution of different graphs while still preventing resource exhaustion
from running too many instances of the same graph.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-authored-by: Claude <noreply@anthropic.com>
2025-10-15 00:02:55 +00:00
seer-by-sentry[bot]
7b8499ec69 feat(backend): Prevent duplicate slugs for store submissions (#11155)
<!-- Clearly explain the need for these changes: -->
This PR prevents users from creating multiple store submissions with the
same slug, which could lead to confusion and potential conflicts in the
marketplace. Each user's submissions should have unique slugs to ensure
proper identification and navigation.

### Changes 🏗️

<!-- Concisely describe all of the changes made in this pull request:
-->
- **Backend**: Added validation to check for existing slugs before
creating new store submissions in `backend/server/v2/store/db.py`
- **Backend**: Introduced new `SlugAlreadyInUseError` exception in
`backend/server/v2/store/exceptions.py` for clearer error handling
- **Backend**: Updated store routes to return HTTP 409 Conflict when
slug is already in use in `backend/server/v2/store/routes.py`
- **Backend**: Cleaned up test file in
`backend/server/v2/store/db_test.py`
- **Frontend**: Enhanced error handling in the publish agent modal to
display specific error messages to users in
`frontend/src/components/contextual/PublishAgentModal/components/AgentInfoStep/useAgentInfoStep.ts`
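
A minimal sketch of the validation (the exception name is from this PR; the set-based lookup stands in for the real query in `store/db.py`, and the route layer maps the exception to HTTP 409):

```python
class SlugAlreadyInUseError(ValueError):
    """Raised when the user already has a submission with this slug."""

def validate_slug_unique(existing_slugs_for_user: set, slug: str) -> None:
    # Uniqueness is scoped per user: another user may reuse the slug.
    if slug in existing_slugs_for_user:
        raise SlugAlreadyInUseError(
            f"Slug '{slug}' is already in use; choose another."
        )
```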

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Add a store submission with a specific slug
- [x] Attempt to add another store submission with the same slug for the
same user - Verify a 409 conflict error is returned with appropriate
error message
- [x] Add a store submission with the same slug, but for a different
user - Verify the submission is successful
- [x] Verify frontend displays the specific error message when slug
conflict occurs
  - [x] Existing tests pass without modification

---------

Co-authored-by: seer-by-sentry[bot] <157164994+seer-by-sentry[bot]@users.noreply.github.com>
Co-authored-by: Swifty <craigswift13@gmail.com>
2025-10-14 11:14:00 +00:00
Abhimanyu Yadav
63076a67e1 fix(frontend): fix client side error handling in custom mutator (#11160)
- depends on https://github.com/Significant-Gravitas/AutoGPT/pull/11159

Currently, the custom mutator doesn’t throw errors for client-side requests; when an error occurs, it is returned as a normal response object instead. As a result, the `onError` callback on React Query mutations never fires and `hasError` never becomes true on queries. This PR fixes that by throwing the client-side error.

### Changes 🏗️
- We’re throwing errors for both server-side and client-side requests.  
- Why server-side? So the client cache isn’t hydrated with the error.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] All end-to-end functionality is working properly.
- [x] I’ve manually checked all the pages and they are all functioning
correctly.
2025-10-14 08:41:57 +00:00
Abhimanyu Yadav
41260a7b4a fix(frontend): fix publish agent behavior when user is logged out (#11159)
When a user clicks the “Become a Creator” button on the marketplace
page, we send an unauthorised request to the server to get a list of
agents. This PR fixes that by checking whether the user is logged in;
if not, we show a UI prompting them to log in or create an account.
 
<img width="967" height="605" alt="Screenshot 2025-10-14 at 12 04 52 PM"
src="https://github.com/user-attachments/assets/95079d9c-e6ef-4d75-9422-ce4fb138e584"
/>

### Changes
- Modify the publish agent test to detect the correct text even when the
user is logged out.
- Use Supabase helpers to check if the user is logged in. If not,
display the appropriate UI.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] The login UI is correctly displayed when the user is logged out.
  - [x] The publish agent UI is correctly displayed when the user is logged in.
  - [x] The login and signup buttons are working perfectly.
2025-10-14 08:41:49 +00:00
Ubbe
5f2d4643f8 feat(frontend): dynamic search terms (#11156)
## Changes 🏗️

<img width="800" height="664" alt="Screenshot 2025-10-14 at 14 09 54"
src="https://github.com/user-attachments/assets/73f6277a-6bef-40f9-b208-31aba0cfc69f"
/>

<img width="600" height="773" alt="Screenshot 2025-10-14 at 14 10 05"
src="https://github.com/user-attachments/assets/c88cb22f-1597-4216-9688-09c19030df89"
/>

Allow to manage on the fly which search terms appear on the Marketplace
page via Launch Darkly dashboard. There is a new flag configured there:
`marketplace-search-terms`:
- **enabled** → `["Foo", "Bar"]` → the terms that will appear
- **disabled** → `[ "Marketing", "SEO", "Content Creation",
"Automation", "Fun"]` → the default ones show

### Small fix

Fix the following browser console warning about `onLoadingComplete`
being deprecated...
<img width="600" height="231" alt="Screenshot 2025-10-14 at 13 55 45"
src="https://github.com/user-attachments/assets/1b26e228-0902-4554-9f8c-4839f8d4ed83"
/>


## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Ran the flag locally and verified it shows the terms set on Launch
Darkly

### For configuration changes:

Launch Darkly new flag needs to be configured on all environments.
2025-10-14 06:43:56 +01:00
Krzysztof Czerwinski
9c8652b273 feat(backend): Whitelist Onboarding Agents (#11149)
Some agents aren't suitable for onboarding. This adds a per-store-agent
setting that marks an agent as allowed for onboarding. If no agents are
allowed, we fall back to the former search.

### Changes 🏗️

- Add `useForOnboarding` to `StoreListing` model and `StoreAgent` view
(with migration)
- Remove filtering of agents with empty input schema or credentials 
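
The selection rule reduces to something like this (a sketch: `useForOnboarding` is the new field from this PR, while the fallback callable stands in for the former search):

```python
def pick_onboarding_agents(store_agents, fallback_search):
    # Prefer explicitly whitelisted agents; otherwise use the old search.
    allowed = [a for a in store_agents if a.get("useForOnboarding")]
    return allowed if allowed else fallback_search()
```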

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Only allowed agents are displayed
- [x] Fallback to the old system in case there aren't enough allowed
agents
2025-10-13 15:05:22 +00:00
Swifty
58ef687a54 fix(platform): Disable logging store terms (#11147)
There is concern that the write load on the database may derail the
performance optimisations.
This hotfix comments out the code that adds the search terms to the db,
so we can discuss how best to do this in a way that won't bring down the
db.

### Changes 🏗️

- commented out the code to log store terms to the db

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] check search still works in dev
2025-10-13 13:17:04 +00:00
Ubbe
c7dcbc64ec fix(frontend): ask for credentials in onboarding agent run (#11146)
## Changes 🏗️

<img width="800" height="852" alt="Screenshot_2025-10-13_at_19 20 47"
src="https://github.com/user-attachments/assets/2fc150b9-1053-4e25-9018-24bcc2d93b43"
/>

<img width="800" height="669" alt="Screenshot 2025-10-13 at 19 23 41"
src="https://github.com/user-attachments/assets/9078b04e-0f65-42f3-ac4a-d2f3daa91215"
/>

- Onboarding “Run” step now renders required credentials (e.g., Google
OAuth) and includes them in execution.
- Run button remains disabled until required inputs and credentials are
provided.
- Logic extracted and strongly typed; removed any usage.

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan ( _once merged
in dev..._ )
  - [ ] Select an onboarding agent that requires Google OAuth:
  - [ ] Credentials selector appears.
  - [ ] After selecting/signing in, “Run agent” enables.
  - [ ] Run succeeds and navigates to the next step.

### For configuration changes:

None
2025-10-13 12:51:45 +00:00
Ubbe
99ac206272 fix(frontend): handle websocket disconnect issue (#11144)
## Changes 🏗️

I found that if I logged out while an agent was running, WebSockets
would sometimes keep connections open but fail to connect (given there
is no token anymore), causing strange behavior down the line on the
login screen.

Two root causes surfaced after inspecting the browser logs 🧐
- WebSocket connections were attempted with an empty token right after
logout, yielding `wss://.../ws?token=` and repeated `1006`
connection-refused loops.
- During logout, sockets in the `CONNECTING` state weren’t being closed,
so the browser kept trying to finish the handshake and reattempted
shortly after each failure.

The fix:
- Guard `connectWebSocket()` to no-op when a logout/disconnect intent is
set, and to skip connecting when no token is available.
- Treat `CONNECTING` sockets as closeable in `disconnectWebSocket()` and
clear `wsConnecting` to avoid stale pending Promises.
- Leave the existing heartbeat/reconnect logic intact; it now won’t run
while we’re logging out or when we can’t get a token.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Login and run an agent that takes long to run
  - [x] Logout
  - [x] Check the browser console and you don't see any socket errors
  - [x] The login screen behaves ok   

### For configuration changes:

Noop
2025-10-13 12:10:16 +00:00
Abhimanyu Yadav
f67d78df3e feat(frontend): Implement discriminator logic in the new builder’s credential system. (#11124)
- Depends on https://github.com/Significant-Gravitas/AutoGPT/pull/11107
and https://github.com/Significant-Gravitas/AutoGPT/pull/11122

In this PR, I’ve added support for discriminators. Now, users can choose
a credential type based on other input values.


https://github.com/user-attachments/assets/6cedc59b-ec84-4ae2-bb06-59d891916847

### Changes 🏗️
- Updated CredentialsField to utilize credentialProvider from schema.
- Refactored helper functions to filter credentials based on the
selected provider.
- Modified APIKeyCredentialsModal and PasswordCredentialsModal to accept
provider as a prop.
- Improved FieldTemplate to dynamically display the correct credential
provider.
- Added getCredentialProviderFromSchema function to manage
multi-provider scenarios.

### Checklist 📋

#### For code changes:
- [x] Credential input is correctly updating based on other input
values.
- [x] Credential can be added correctly.
2025-10-13 12:08:10 +00:00
Swifty
e32c509ccc feat(backend): Simplify caching to just store routes (#11140)
### Problem
The store cache was more complex than it needed to be; caching should be
limited to just the main marketplace routes.

### Changes 🏗️

- **Simplified store cache implementation** in
`backend/server/v2/store/cache.py`
  - Streamlined caching logic for better maintainability
  - Reduced complexity while maintaining performance
  
- **Added cache invalidation on store updates**
  - Implemented cache clearing when new agents are added to the store
- Added invalidation logic in admin store routes
(`admin_store_routes.py`)
  - Ensures all pods reflect the latest store state after modifications

- **Updated store database operations** in
`backend/server/v2/store/db.py`
  - Modified to work with the new cache structure
  
- **Added cache deletion tests** (`test_cache_delete.py`)
  - Validates cache invalidation works correctly
  - Ensures cache consistency across operations
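
The invalidate-on-update pattern described above can be sketched roughly as follows. This is a minimal in-memory illustration — the names (`cache_store`, `cached_route`, `invalidate`) are hypothetical, not the actual helpers in `backend/server/v2/store/cache.py`:

```python
# Minimal sketch of route-level caching with explicit invalidation.
import time
from functools import wraps

cache_store: dict = {}  # (route name, args, kwargs) -> (timestamp, value)
TTL_SECONDS = 60

def cached_route(func):
    """Cache a route's result, keyed by its name and arguments."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        key = (func.__name__, args, tuple(sorted(kwargs.items())))
        hit = cache_store.get(key)
        if hit and time.time() - hit[0] < TTL_SECONDS:
            return hit[1]
        value = func(*args, **kwargs)
        cache_store[key] = (time.time(), value)
        return value
    return wrapper

def invalidate(route_name: str) -> None:
    # Drop every cached entry for the given route, e.g. after an admin
    # approves a new store agent, so fresh data is re-read on next request.
    for key in [k for k in cache_store if k[0] == route_name]:
        del cache_store[key]

@cached_route
def get_store_agents(page: int = 1):
    return {"page": page, "agents": ["..."]}
```

In a multi-pod deployment the invalidation would of course have to reach a shared cache rather than a per-process dict; the sketch only shows the control flow.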

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verify store listings are cached correctly
  - [x] Upload a new agent to the store and confirm cache is invalidated
2025-10-13 07:25:59 +00:00
seer-by-sentry[bot]
20acd8b51d fix(backend): Improve Postmark error handling and logging for notification delivery (#11052)
<!-- Clearly explain the need for these changes: -->
Fixes
[AUTOGPT-SERVER-5K6](https://sentry.io/organizations/significant-gravitas/issues/6887660207/).
The issue was that: Batch sending fails due to malformed data (422) and
inactive recipients (406); the 406 error is misclassified as a size
limit failure.

- Implements more robust error handling for Postmark API failures during
notification sending.
- Specifically handles inactive recipients (HTTP 406), malformed data
(HTTP 422), and oversized notifications.
- Adds detailed logging for each error case, including the notification
index and error message.
- Skips individual notifications that fail due to these errors,
preventing the entire batch from failing.
- Improves error handling for ValueErrors during send_templated calls,
specifically addressing oversized notifications.


This fix was generated by Seer in Sentry, triggered by Nicholas Tindle.
👁️ Run ID: 1675950

Not quite right? [Click here to continue debugging with
Seer.](https://sentry.io/organizations/significant-gravitas/issues/6887660207/?seerDrawer=true)

### Changes 🏗️

<!-- Concisely describe all of the changes made in this pull request:
-->
- Implements more robust error handling for Postmark API failures during
notification sending.
- Specifically handles inactive recipients (HTTP 406), malformed data
(HTTP 422), and oversized notifications.
- Adds detailed logging for each error case, including the notification
index and error message.
- Skips individual notifications that fail due to these errors,
preventing the entire batch from failing.
- Improves error handling for ValueErrors during send_templated calls,
specifically addressing oversized notifications.
- Also disables this in prod to prevent scaling issues until we work out
some of the more critical issues
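
The skip-and-log behavior above can be sketched roughly as follows. The exception class and client call are illustrative stand-ins, not the actual Postmark SDK API:

```python
# Hedged sketch: handle 406/422 per notification instead of failing the batch.
import logging

logger = logging.getLogger(__name__)

class PostmarkError(Exception):
    """Illustrative stand-in for a Postmark API error with an HTTP status."""
    def __init__(self, status: int, message: str):
        super().__init__(message)
        self.status = status

def send_batch(client, notifications: list) -> list:
    delivered = []
    for i, notification in enumerate(notifications):
        try:
            client.send(notification)
            delivered.append(notification)
        except PostmarkError as e:
            if e.status == 406:  # inactive recipient: skip, don't retry
                logger.warning("Notification %d skipped (inactive recipient): %s", i, e)
            elif e.status == 422:  # malformed data: skip and log
                logger.error("Notification %d malformed: %s", i, e)
            else:
                raise  # unknown errors still surface
        except ValueError as e:  # oversized notification from send_templated
            logger.error("Notification %d oversized, skipping: %s", i, e)
    return delivered
```

The key point is that each error case logs the notification index and message, then continues the loop, so one bad recipient no longer aborts the whole batch.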

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
- [x] Test sending notifications with invalid email addresses to ensure
406 errors are handled correctly.
- [x] Test sending notifications with malformed data to ensure 422
errors are handled correctly.
- [x] Test sending oversized notifications to ensure they are skipped
and logged correctly.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- New Features
  - None

- Bug Fixes
- Individual email failures no longer abort a batch; processing
continues after per-recipient errors.
- Specific handling for inactive recipients and malformed messages to
prevent repeated delivery attempts.

- Chores
  - Improved error logging and diagnostics for email delivery scenarios.

- Tests
- Added tests covering email-sending error cases, user-deactivation on
inactive addresses, and batch-continuation behavior.

- Documentation
  - None
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: seer-by-sentry[bot] <157164994+seer-by-sentry[bot]@users.noreply.github.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <ntindle@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2025-10-13 07:16:48 +00:00
Nicholas Tindle
a49c957467 Revert "fix(frontend/builder): Sync frontend node IDs with backend after save" (#11142)
Reverts Significant-Gravitas/AutoGPT#11075
2025-10-13 07:16:02 +00:00
Abhimanyu Yadav
cf6e724e99 feat(platform): load graph on new builder (#11141)
In this PR, I’ve added functionality to fetch a graph based on the
flowID and flowVersion provided in the URL. Once the graph is fetched,
we add its nodes and links to the new builder using the graph data.

<img width="1512" height="982" alt="Screenshot 2025-10-11 at 10 26
07 AM"
src="https://github.com/user-attachments/assets/2f66eb52-77b2-424c-86db-559ea201b44d"
/>


### Changes
- Added `get_specific_blocks` route in `routes.py`.
- Created `get_block_by_id` function in `db.py`.
- Added a new hook, `useFlow.ts`, to load the graph and populate it in the
flow editor.
- Updated frontend components to reflect changes in block handling.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Able to load the graph correctly.
  - [x] Able to populate it on the builder.
2025-10-11 15:28:37 +00:00
Reinier van der Leer
b67555391d fix(frontend/builder): Sync frontend node IDs with backend after save (#11075)
- Resolves #10980

Fixes unnecessary graph re-saving when no changes were made after
initial save. The issue occurred because frontend node IDs weren't
synced with backend IDs after save operations.

### Changes 🏗️

- Update actual node.id to match backend node ID after save
- Update edge references with new node IDs
- Properly sync visual editor state with backend

### Test Plan 📋

- [x] TypeScript compilation passes  
- [x] Pre-commit hooks pass
- [x] Manual test: Save graph, verify no re-save needed on subsequent
save/run
2025-10-11 01:12:19 +00:00
Zamil Majdy
05a72f4185 feat(backend): implement user rate limiting for concurrent graph executions (#11128)
## Summary
Add configurable rate limiting to prevent users from exceeding the
maximum number of concurrent graph executions, defaulting to 50 per
user.

## Changes Made

### Configuration (`backend/util/settings.py`)
- Add `max_concurrent_graph_executions_per_user` setting (default: 50,
range: 1-1000)
- Configurable via environment variables or settings file

### Database Query Function (`backend/data/execution.py`) 
- Add `get_graph_executions_count()` function for efficient count
queries
- Supports filtering by user_id, statuses, and time ranges
- Used to check current RUNNING/QUEUED executions per user

### Database Manager Integration (`backend/executor/database.py`)
- Expose `get_graph_executions_count` through DatabaseManager RPC
interface
- Follows existing patterns for database operations
- Enables proper service-to-service communication

### Rate Limiting Logic (`backend/executor/manager.py`)
- Inline rate limit check in `_handle_run_message()` before cluster lock
- Use existing `db_client` pattern for consistency
- Reject and requeue executions when limit exceeded
- Graceful error handling - proceed if rate limit check fails
- Enhanced logging with user_id and current/max execution counts

## Technical Implementation
- **Database approach**: Query actual execution statuses for accuracy
- **RPC pattern**: Use DatabaseManager client following existing
codebase patterns
- **Fail-safe design**: Proceed with execution if rate limit check fails
- **Requeue on limit**: Rejected executions are requeued for later
processing
- **Early rejection**: Check rate limit before expensive cluster lock
operations

## Rate Limiting Flow
1. Parse incoming graph execution request
2. Query database via RPC for user's current RUNNING/QUEUED execution
count
3. Compare against configured limit (default: 50)
4. If limit exceeded: reject and requeue message
5. If within limit: proceed with normal execution flow

## Configuration Example
```env
MAX_CONCURRENT_GRAPH_EXECUTIONS_PER_USER=25  # Reduce to 25 for stricter limits
```

## Test plan
- [x] Basic functionality tested - settings load correctly, database
function works
- [x] ExecutionManager imports and initializes without errors
- [x] Database manager exposes the new function through RPC
- [x] Code follows existing patterns and conventions
- [ ] Integration testing with actual rate limiting scenarios
- [ ] Performance testing to ensure minimal impact on execution pipeline

🤖 Generated with [Claude Code](https://claude.ai/code)

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-10-11 08:02:34 +07:00
Swifty
36f634c417 fix(backend): Update store agent view to return only latest version (#11065)
This PR fixes duplicate agent listings on the marketplace home page by
updating the StoreAgent view to return only the latest approved version
of each agent.

### Changes 🏗️

- Updated `StoreAgent` database view to filter for only the latest
approved version per listing
- Added CTE (Common Table Expression) `latest_versions` to efficiently
determine the maximum version for each store listing
- Modified the join logic to only include the latest version instead of
all approved versions
- Updated `versions` array field to contain only the single latest
version

**Technical details:**
- The view now uses a `latest_versions` CTE that groups by
`storeListingId` and finds `MAX(version)` for approved submissions
- Join condition ensures only the latest version is included:
`slv.version = lv.latest_version`
- This prevents duplicate entries for agents with multiple approved
versions

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verified marketplace home page shows no duplicate agents
- [x] Confirmed only latest version is displayed for agents with
multiple approved versions
  - [x] Checked that agent details page still functions correctly
  - [x] Validated that run counts and ratings are still accurate

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)
2025-10-10 09:31:36 +00:00
Swifty
18e169aa51 feat(platform): Log Marketplace Search Terms (#11092)
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Reinier van der Leer <Pwuts@users.noreply.github.com>
2025-10-10 11:33:28 +02:00
Swifty
c5b90f7b09 feat(platform): Simplify running of core docker services (#11113)
Co-authored-by: vercel[bot] <35613825+vercel[bot]@users.noreply.github.com>
2025-10-10 11:32:46 +02:00
Ubbe
a446c1acc9 fix(frontend): improve navbar on mobile (#11137)
## Changes 🏗️

Make the navigation bar look nice across screen sizes 📱 

<img width="1229" height="388" alt="Screenshot 2025-10-10 at 17 53 48"
src="https://github.com/user-attachments/assets/037a9957-9c0b-4e2c-9ef5-af198fdce923"
/>

<img width="700" height="392" alt="Screenshot 2025-10-10 at 17 53 42"
src="https://github.com/user-attachments/assets/bf9a0f83-a528-4613-83e7-6e204078b905"
/>

<img width="500" height="377" alt="Screenshot 2025-10-10 at 17 52 24"
src="https://github.com/user-attachments/assets/2209d4f3-a41a-4700-894b-5e6e7c15fefb"
/>

<img width="300" height="381" alt="Screenshot 2025-10-10 at 17 52 16"
src="https://github.com/user-attachments/assets/1c87d545-784e-47b5-b23c-6f37cfae489b"
/>


## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Login to the platform and resize the window down
- [x] The navbar looks good across screen sizes and everything is
aligned and accessible

### For configuration changes:

None
2025-10-10 09:10:24 +00:00
Ubbe
59d242f69c fix(frontend): remove agent activity flag (#11136)
## Changes 🏗️

The Agent Activity Dropdown is now stable; it has been 100% exposed to
users in production for a few weeks, so there's no need to keep it behind
a flag anymore.

## Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Login to AutoGPT
- [x] The bell on the navbar is always present even if the flag on
Launch Darkly is turned off

### For configuration changes:

None
2025-10-10 09:08:42 +00:00
Abhimanyu Yadav
a2cd5d9c1f feat(frontend): add support for user password credentials in new FlowEditor (#11122)
- depends on https://github.com/Significant-Gravitas/AutoGPT/pull/11107

In this PR, I’ve added a way to add a username and password as
credentials in the new builder.


https://github.com/user-attachments/assets/b896ea62-6a6d-487c-99a3-727cef4ad9a5

### Changes 🏗️
- Introduced PasswordCredentialsModal to handle user password
credentials.
- Updated useCredentialField to support user password type.
- Refactored APIKeyCredentialsModal to remove unnecessary onSuccess
prop.
- Enhanced the CredentialsField component to conditionally render the
new password modal based on supported credential types.

### Checklist 📋

#### For code changes:
- [x] Ability to add username and password correctly.
- [x] The username and password are visible in the credentials list
after adding it.
2025-10-10 07:15:21 +00:00
Abhimanyu Yadav
df5b348676 feat(frontend): add search functionality in new block menu (#11121)
- Depends on https://github.com/Significant-Gravitas/AutoGPT/pull/11120

In this PR, I’ve added search functionality to the new block menu, with
pagination.



https://github.com/user-attachments/assets/4c199997-4b5a-43c7-83b6-66abb1feb915



### Changes 🏗️
- Add a frontend for the search list with pagination functionality.
- Updated the search route to use GET method.
- Removed the SearchRequest model and replaced it with individual query
parameters.

### Checklist 📋

#### For code changes:
- [x] The search functionality is working perfectly.
- [x] If the search query has no matches, it correctly displays a “No
Result” UI.
2025-10-09 12:28:12 +00:00
Bently
4856bd1f3a fix(backend): prevent sub-agent execution visibility across users (#11132)
Fixes an issue where sub-agent executions triggered by one user were
visible in the original agent author's execution library.

## Solution

Fixed the user_id attribution in
`autogpt_platform/backend/backend/executor/manager.py` by ensuring that
sub-agent executions always use the actual executor's user_id rather
than the agent author's user_id stored in node defaults.

### Changes
- Added user_id override in `execute_node()` function when preparing
AgentExecutorBlock input (line 194)
- Ensures sub-agent executions are correctly attributed to the user
running them, not the agent author
- Maintains proper privacy isolation between users in marketplace agent
scenarios
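
The fix boils down to a single override when assembling the sub-agent's input. A minimal sketch, with illustrative field names rather than the exact `execute_node()` code:

```python
# Sketch of the attribution fix: when preparing input for an
# AgentExecutorBlock, override any author user_id baked into node defaults
# with the id of the user actually executing the graph.
def prepare_sub_agent_input(node_defaults: dict, executor_user_id: str) -> dict:
    data = {**node_defaults}
    # Without this override, the sub-agent execution would be attributed to
    # the agent author whose id was stored in the node defaults.
    data["user_id"] = executor_user_id
    return data
```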

### Security Impact
- **Before**: When User B downloaded and ran a marketplace agent
containing sub-agents owned by User A, the sub-agent executions appeared
in User A's library
- **After**: Sub-agent executions now only appear in the library of the
user who actually ran them
- Prevents unauthorized access to execution data and user privacy
violation

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Test plan: -->
  - [x] Create an agent with sub-agents as User A
  - [x] Publish agent to marketplace
  - [x] Run the agent as User B
- [x] Verify User A cannot see User B's sub-agent executions in their
library
  - [x] Verify User B can see their own sub-agent executions
  - [x] Verify primary agent executions remain correctly filtered
2025-10-09 11:17:26 +00:00
Abhimanyu Yadav
2e1d3dd185 refactor(frontend): replace context api in new block menu with zustand store (#11120)
Currently, we use the Context API for the block menu provider. To access
some of its state outside the BlockMenuProvider wrapper (for instance, in
the tutorial), we would need to move the wrapper higher up the tree,
perhaps to the top of the builder tree, which would cause unnecessary
re-renders. Therefore, we should create a block menu zustand store so that
we can easily access its state in other parts of the builder.

### Changes 🏗️
- Deleted `block-menu-provider.tsx` file.
- Updated BlockMenu, BlockMenuContent, BlockMenuDefaultContent, and
other components to utilize blockMenuStore instead of
BlockMenuStateProvider.
- Adjusted imports and context usage accordingly.

### Checklist 📋
- [x] Changes have been clearly listed.
- [x] Code has been tested and verified.
- [x] I’ve checked every part of the block menu where we used the
context API and it’s working perfectly.
- [x] Ability to use block menu state in other parts of the builder.
2025-10-09 11:04:42 +00:00
Abhimanyu Yadav
ff72343035 feat(frontend): add UI for sticky notes on new builder (#11123)
Currently, the new builder doesn’t support sticky notes; they were being
rendered as normal nodes with an input. This PR adds a dedicated UI for
them.

<img width="1512" height="982" alt="Screenshot 2025-10-08 at 4 12 58 PM"
src="https://github.com/user-attachments/assets/be716e45-71c6-4cc4-81ba-97313426222f"
/>

To add sticky notes, go to the search menu of the block menu and search
for “Note block”. Then, add them from there.

### Changes 🏗️
- Updated CustomNodeData to include uiType.
- Conditional rendering in CustomNode based on uiType.
- Added a custom sticky note UI component called `StickyNoteBlock.tsx`.
- Adjusted FormCreator and FieldTemplate to pass and utilize uiType.
- Enhanced TextInputWidget to render differently based on uiType.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Able to attach sticky notes to the builder.
- [x] Able to accurately capture data while writing on sticky notes and
data is persistent also
2025-10-09 06:48:19 +00:00
Abhimanyu Yadav
7982c34450 feat(frontend): add oauth2 credential support in new builder (#11107)
In this PR, I have added support for OAuth2 in the new builder.


https://github.com/user-attachments/assets/89472ebb-8ec2-467a-9824-79a80a71af8a

### Changes 🏗️
- Updated the FlowEditor to support OAuth2 credential selection.
- Improved the UI for API key and OAuth2 modals, enhancing user
experience.
- Refactored credential field components for better modularity and
maintainability.
- Updated OpenAPI documentation to reflect changes in OAuth flow
endpoints.

### Checklist 📋
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Able to create OAuth credentials
  - [x] OAuth2 is correctly selected using the Credential Selector.
2025-10-09 06:47:15 +00:00
Zamil Majdy
59c27fe248 feat(backend): implement comprehensive rate-limited Discord alerting system (#11106)
## Summary
Implement comprehensive Discord alerting system with intelligent rate
limiting to prevent spam and provide proper visibility into system
failures across retry mechanisms and execution errors.

## Key Features

### 🚨 Rate-Limited Discord Alerting Infrastructure
- **Reusable rate-limited alerts**: `send_rate_limited_discord_alert()`
function for any Discord alerts
- **5-minute rate limiting**: Prevents spam for identical error
signatures (function+error+context)
- **Thread-safe**: Proper locking for concurrent alert attempts
- **Configurable channels**: Support custom Discord channels or default
to PLATFORM
- **Graceful failure handling**: Alert failures don't break main
application flow

### 🔄 Enhanced Retry Alert System
- **Unified threshold alerting**: Both general retries and
infrastructure retries alert at EXCESSIVE_RETRY_THRESHOLD (50 attempts)
- **Critical retry alerts**: Early warning when operations approach
failure threshold
- **Infrastructure monitoring**: Dedicated alerts for database, Redis,
RabbitMQ connection issues
- **Rate limited**: All retry alerts use rate limiting to prevent
overwhelming Discord channels

### 📊 Unknown Execution Error Alerts
- **Automatic error detection**: Alert for unexpected graph execution
failures
- **Rich context**: Include user ID, graph ID, execution ID, error type
and message
- **Filtered alerts**: Skip known errors (InsufficientBalanceError,
ModerationError)
- **Proper error tracking**: Ensure execution_stats.error is set for all
error types

## Technical Implementation

### Rate Limiting Strategy
```python
# Create unique signatures based on function+error+context
error_signature = f"{context}:{func_name}:{type(exception).__name__}:{str(exception)[:100]}"
```
- **5-minute windows**: ALERT_RATE_LIMIT_SECONDS = 300 prevents
duplicate alerts
- **Memory efficient**: Only store last alert timestamp per unique error
signature
- **Context awareness**: Same error in different contexts can send
separate alerts

### Alerting Hierarchy
1. **50 attempts**: Critical alert warning about approaching failure
(EXCESSIVE_RETRY_THRESHOLD)
2. **100 attempts**: Final infrastructure failure (conn_retry max_retry)
3. **Unknown execution errors**: Immediate rate-limited alerts for
unexpected failures

## Files Modified

### Core Implementation
- `backend/executor/manager.py`: Unknown execution error alerts with
rate limiting
- `backend/util/retry.py`: Comprehensive rate-limited alerting
infrastructure
- `backend/util/retry_test.py`: Full test coverage for rate limiting
functionality (14 tests)

### Code Quality Improvements
- **Inlined alert messages**: Eliminated unnecessary temporary variables
- **Simplified logic**: Removed excessive comments and redundant alerts
- **Consistent patterns**: All alert functions follow same clean code
style
- **DRY principle**: Reusable rate-limited alert system for future
monitoring needs

## Benefits

### 🛡️ Prevents Alert Spam
- **Rate limiting**: No more overwhelming Discord channels with
duplicate alerts
- **Intelligent deduplication**: Same errors rate limited while
different errors get through
- **Thread safety**: Concurrent operations handled correctly

### 🔍 Better System Visibility  
- **Unknown errors**: Issues that need investigation are properly
surfaced
- **Infrastructure monitoring**: Early warning for
database/Redis/RabbitMQ issues
- **Rich context**: All necessary debugging information included in
alerts

### 🧹 Maintainable Codebase
- **Reusable infrastructure**: `send_rate_limited_discord_alert()` for
future monitoring
- **Clean, consistent code**: Inlined messages, simplified logic, proper
abstractions
- **Comprehensive testing**: Rate limiting edge cases and real-world
scenarios covered

## Validation Results
-  All 14 retry tests pass including comprehensive rate limiting
coverage
-  Manager execution tests pass validating integration with execution
flow
-  Thread safety validated with concurrent alert attempt tests
-  Real-world scenarios tested including the specific spend_credits
spam issue that motivated this work
-  Code formatting, linting, and type checking all pass

## Before/After Comparison

### Before
- No rate limiting → Discord spam for repeated errors
- Unknown execution errors not monitored → Issues went unnoticed  
- Inconsistent alerting thresholds → Confusing monitoring
- Verbose code with temporary variables → Harder to maintain

### After  
-  Rate-limited intelligent alerting prevents spam
-  Unknown execution errors properly monitored with context
-  Unified 50-attempt threshold for consistent monitoring
-  Clean, maintainable code with reusable infrastructure

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-10-09 08:22:15 +07:00
Zamil Majdy
c7575dc579 fix(backend): implement rate limiting for critical retry alerts to prevent spam (#11127)
## Summary
Fix the critical issue where retry failure alerts were being spammed
when service communication failed repeatedly.

## Problem
The service communication retry mechanism was sending a critical Discord
alert every time it hit the 50-attempt threshold, with no rate limiting.
This caused alert spam when the same operation (like spend_credits) kept
failing repeatedly with the same error.

## Solution

### Rate Limiting Implementation
- Add ALERT_RATE_LIMIT_SECONDS = 300 (5 minutes) between identical
alerts
- Create _should_send_alert() function with thread-safe rate limiting
using _alert_rate_limiter dict
- Generate unique signatures based on
context:func_name:exception_type:exception_message
- Only send alert if sufficient time has passed since last identical
alert

### Enhanced Logging  
- Rate-limited alerts log as warnings instead of being silently dropped
- Add full exception tracebacks for final failures and every 10th retry
attempt
- Improve alert message clarity and add note about rate limiting
- Better structured logging with exception types and details

### Error Context Preservation
- Maintain all original retry behavior and exception handling
- Preserve critical alerting for genuinely new issues  
- Clean up alert message (removed accidental paste from error logs)

## Technical Details
- Thread-safe implementation using threading.Lock() for rate limiter
access
- Signature includes first 100 chars of exception message for
granularity
- Memory efficient - only stores last alert timestamp per unique error
type
- No breaking changes to existing retry functionality

## Impact
- **Eliminates alert spam**: Same failing operation only alerts once per
5 minutes
- **Preserves critical alerts**: New/different failures still trigger
immediate alerts
- **Better debugging**: Enhanced logging provides full exception context
- **Maintains reliability**: All retry logic works exactly as before

## Testing
-  Rate limiting tested with multiple scenarios
-  Import compatibility verified 
-  No regressions in retry functionality
-  Alert generation works for new vs repeated errors

## Type of Change
- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] This change requires a documentation update

## How Has This Been Tested?
- Manual testing of rate limiting functionality with different error
scenarios
- Import verification to ensure no regressions
- Code formatting and linting compliance

## Checklist
- [x] My code follows the style guidelines of this project
- [x] I have performed a self-review of my own code
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have made corresponding changes to the documentation (N/A -
internal utility)
- [x] My changes generate no new warnings
- [x] Any dependent changes have been merged and published in downstream
modules (N/A)
2025-10-09 05:53:10 +07:00
Ubbe
73603a8ce5 fix(frontend): onboarding re-directs (#11126)
## Changes 🏗️

We weren't awaiting the onboarding-enabled check, and we were also
re-directing to the wrong onboarding URL.

## Checklist 📋

### For code changes

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Create a new user
  - [x] Re-directs well to onboarding
  - [x] Complete up to Step 5 and logout
  - [x] Login again
  - [x] You are on Step 5  

#### For configuration changes:

None
2025-10-08 15:18:25 +00:00
Ubbe
e562ca37aa fix(frontend): login redirects + onboarding (#11125)
## Changes 🏗️

### Fix re-direct bugs

Sometimes the app will re-direct to a strange URL after login:
```
http://localhost:3000/marketplace,%20/marketplace
```
It looks like a race-condition because the re-direct to `/marketplace`
was done on a [server
action](https://nextjs.org/docs/14/app/building-your-application/data-fetching/server-actions-and-mutations)
rather than in the browser.

**Fixed by**

Moving the login / signup server actions to Next.js API endpoints. In
this way the login/signup pages just call an API endpoint and handle its
response without having to hassle with serverless 💆🏽

### Wallet layout flash

<img width="800" height="744" alt="Screenshot 2025-10-08 at 22 52 03"
src="https://github.com/user-attachments/assets/7cb85fd5-7dc4-4870-b4e1-173cc8148e51"
/>

The wallet popover would sometimes flash after login, because it was
re-rendering once onboarding and credits data loaded.

**Fixed by**

Only rendering the popover once we have onboarding and credits data;
without that data it is useless and causes flashes.

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Login / Signup to the app with email and Google
  - [x] Works fine
  - [x] Onboarding popover does not flash
  - [x] Onboarding and marketplace re-directs work   

### For configuration changes:

None
2025-10-08 18:35:45 +04:00
Nicholas Tindle
f906fd9298 fix(backend): Allow Project.content to be optional for linear search projects (#11118)
Changed the type of the `content` field in the Project model to accept
`None`, making it optional instead of required. Linear doesn't always
return data here if it's not set by the user.

<!-- Clearly explain the need for these changes: -->

### Changes 🏗️
- Makes the content optional
<!-- Concisely describe all of the changes made in this pull request:
-->

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Manually test it works with our data


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Bug Fixes**
- Improved handling of projects with no content by making content
optional.
- Prevents validation errors during project creation, import, or sync
when content is missing.
- Enhances compatibility with integrations that may omit content fields.
- No impact on existing projects with content; behavior remains
unchanged.
  - No user action required.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-10-07 20:04:37 +00:00
seer-by-sentry[bot]
9e79add436 fix(backend): Change progress type to float in Linear Project (#11117)
### Changes 🏗️

- Changed the type of the `progress` field in the `LinearTask` model
from `int` to `float` to fix
[BUILDER-3V5](https://sentry.io/organizations/significant-gravitas/issues/6929150079/).

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
- [x] Root cause analysis confirms the fix -- testing will need to occur
in the dev environment before release to prod, but this should merge now


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- New Features
- Progress indicators now support decimal values, allowing more precise
tracking (e.g., 42.5% instead of 42%). This enables finer-grained
updates in the interface and any integrations consuming progress data.
- Users may notice smoother progress changes during long-running tasks,
with improved accuracy in percentage displays across relevant views and
APIs.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Co-authored-by: seer-by-sentry[bot] <157164994+seer-by-sentry[bot]@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2025-10-07 17:59:26 +00:00
Nicholas Tindle
de6f4fca23 Merge branch 'master' into dev 2025-10-07 11:13:38 -05:00
Nicholas Tindle
fb4b8ed9fc feat: track users with sentry on client side (not backend yet) (#11077)
<!-- Clearly explain the need for these changes: -->
We need to be able to measure user impact by user count, which means we
need to track which user triggered each event
### Changes 🏗️
- Attaches user id to frontend actions (which hopefully propagate to the
backend in some places)
<!-- Concisely describe all of the changes made in this pull request:
-->

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Test login -> shows on sentry
  - [x] Test logout -> no longer shows on sentry
2025-10-07 15:35:57 +00:00
Zamil Majdy
f3900127d7 feat(backend): instrument prometheus for internal services (#11114)
<!-- Clearly explain the need for these changes: -->

### Changes 🏗️

Instrument Prometheus for internal services

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Existing tests
2025-10-07 21:34:38 +07:00
Ubbe
0978566089 fix(frontend): performance and layout issues (#11036)
## Changes 🏗️

### Performance (Onboarding) 🐎 

- Moved non-UI logic into `providers/onboarding/helpers.ts` to reduce
provider complexity.
- Memoized provider value and narrowed state updates to cut unnecessary
re-renders.
- Deferred non-critical effects until after mount to lower initial JS
work.
 
**Result:** faster initial render and smoother onboarding flows under
load.

### Layout and overflow fixes 📐 

- Replaced `w-screen` with `w-full` in platform/admin/profile layouts
and marketplace wrappers to avoid 100vw scrollbar overflow.
- Adjusted mobile navbar position (`right-0` instead of `-right-4`) to
prevent off-viewport elements.

**Result:** removed horizontal scrolling on Marketplace, Library, and
Settings pages; Build remains unaffected.

### New Generic Error pages

- Standardized global error handling in `app/global-error.tsx` for
consistent display and user feedback.
- Added platform-scoped error page(s) under `app/(platform)/error` for
route-level failures with a consistent layout.
- Improved retry affordances using existing `ErrorCard`.

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Verify onboarding flows render faster and re-render less (DevTools
flamegraph)
- [x] Confirm no horizontal scrolling on Marketplace, Library, Settings
at common widths
  - [x] Validate mobile navbar stays within viewport
- [x] Trigger errors to confirm global and platform error pages render
consistently

### For configuration changes:

None
2025-10-03 22:41:01 +09:00
247 changed files with 9489 additions and 4790 deletions


@@ -217,9 +217,6 @@ jobs:
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Generate API client
run: pnpm generate:api
- name: Install Browser 'chromium'
run: pnpm playwright install --with-deps chromium

47
autogpt_platform/Makefile Normal file

@@ -0,0 +1,47 @@
.PHONY: start-core stop-core logs-core format lint migrate run-backend run-frontend
# Run just Supabase + Redis + RabbitMQ
start-core:
docker compose up -d deps
# Stop core services
stop-core:
docker compose stop deps
# View logs for core services
logs-core:
docker compose logs -f deps
# Run formatting and linting for backend and frontend
format:
cd backend && poetry run format
cd frontend && pnpm format
cd frontend && pnpm lint
init-env:
cp -n .env.default .env || true
cd backend && cp -n .env.default .env || true
cd frontend && cp -n .env.default .env || true
# Run migrations for backend
migrate:
cd backend && poetry run prisma migrate deploy
cd backend && poetry run prisma generate
run-backend:
cd backend && poetry run app
run-frontend:
cd frontend && pnpm dev
help:
@echo "Usage: make <target>"
@echo "Targets:"
@echo " start-core - Start just the core services (Supabase, Redis, RabbitMQ) in background"
@echo " stop-core - Stop the core services"
@echo " logs-core - Tail the logs for core services"
@echo " format - Format & lint backend (Python) and frontend (TypeScript) code"
@echo " migrate - Run backend database migrations"
@echo " run-backend - Run the backend FastAPI server"
@echo " run-frontend - Run the frontend Next.js development server"


@@ -38,6 +38,37 @@ To run the AutoGPT Platform, follow these steps:
4. After all the services are in ready state, open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
### Running Just Core services
You can now run the following to enable just the core services.
```
# For help
make help
# Run just Supabase + Redis + RabbitMQ
make start-core
# Stop core services
make stop-core
# View logs from core services
make logs-core
# Run formatting and linting for backend and frontend
make format
# Run migrations for backend database
make migrate
# Run backend server
make run-backend
# Run frontend development server
make run-frontend
```
### Docker Compose Commands
Here are some useful Docker Compose commands for managing your AutoGPT Platform:


@@ -1719,22 +1719,6 @@ files = [
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
strenum = ">=0.4.15,<0.5.0"
[[package]]
name = "tenacity"
version = "9.1.2"
description = "Retry code until it succeeds"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138"},
{file = "tenacity-9.1.2.tar.gz", hash = "sha256:1169d376c297e7de388d18b4481760d478b0e99a777cad3a9c86e556f4b697cb"},
]
[package.extras]
doc = ["reno", "sphinx"]
test = ["pytest", "tornado (>=4.5)", "typeguard"]
[[package]]
name = "tomli"
version = "2.2.1"
@@ -1945,4 +1929,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<4.0"
content-hash = "5ec9e6cd2ef7524a356586354755215699e7b37b9bbdfbabc9c73b43085915f4"
content-hash = "0c40b63c3c921846cf05ccfb4e685d4959854b29c2c302245f9832e20aac6954"


@@ -19,7 +19,6 @@ pydantic-settings = "^2.10.1"
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
redis = "^6.2.0"
supabase = "^2.16.0"
tenacity = "^9.1.2"
uvicorn = "^0.35.0"
[tool.poetry.group.dev.dependencies]


@@ -16,7 +16,7 @@ if TYPE_CHECKING:
T = TypeVar("T")
@cached(ttl_seconds=3600) # Cache blocks for 1 hour
@cached(ttl_seconds=3600)
def load_all_blocks() -> dict[str, type["Block"]]:
from backend.data.block import Block
from backend.util.settings import Config


@@ -66,6 +66,7 @@ class AddToDictionaryBlock(Block):
dictionary: dict[Any, Any] = SchemaField(
default_factory=dict,
description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
advanced=False,
)
key: str = SchemaField(
default="",


@@ -171,11 +171,11 @@ class SendDiscordMessageBlock(Block):
description="The content of the message to send"
)
channel_name: str = SchemaField(
description="The name of the channel the message will be sent to"
description="Channel ID or channel name to send the message to"
)
server_name: str = SchemaField(
description="The name of the server where the channel is located",
advanced=True, # Optional field for server name
description="Server name (only needed if using channel name)",
advanced=True,
default="",
)
@@ -231,25 +231,49 @@ class SendDiscordMessageBlock(Block):
@client.event
async def on_ready():
print(f"Logged in as {client.user}")
for guild in client.guilds:
if server_name and guild.name != server_name:
continue
for channel in guild.text_channels:
if channel.name == channel_name:
# Split message into chunks if it exceeds 2000 characters
chunks = self.chunk_message(message_content)
last_message = None
for chunk in chunks:
last_message = await channel.send(chunk)
result["status"] = "Message sent"
result["message_id"] = (
str(last_message.id) if last_message else ""
)
result["channel_id"] = str(channel.id)
await client.close()
return
channel = None
result["status"] = "Channel not found"
# Try to parse as channel ID first
try:
channel_id = int(channel_name)
channel = client.get_channel(channel_id)
except ValueError:
# Not a valid ID, will try name lookup
pass
# If not found by ID (or not an ID), try name lookup
if not channel:
for guild in client.guilds:
if server_name and guild.name != server_name:
continue
for ch in guild.text_channels:
if ch.name == channel_name:
channel = ch
break
if channel:
break
if not channel:
result["status"] = f"Channel not found: {channel_name}"
await client.close()
return
# Type check - ensure it's a text channel that can send messages
if not hasattr(channel, "send"):
result["status"] = (
f"Channel {channel_name} cannot receive messages (not a text channel)"
)
await client.close()
return
# Split message into chunks if it exceeds 2000 characters
chunks = self.chunk_message(message_content)
last_message = None
for chunk in chunks:
last_message = await channel.send(chunk) # type: ignore
result["status"] = "Message sent"
result["message_id"] = str(last_message.id) if last_message else ""
result["channel_id"] = str(channel.id)
await client.close()
await client.start(token)
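Both paths in this diff rely on splitting messages at Discord's 2000-character limit before sending. The real `chunk_message` is a method on the block and isn't shown in this hunk; a hypothetical standalone version could look roughly like this:

```python
def chunk_message(content: str, limit: int = 2000) -> list[str]:
    """Split content into pieces no longer than `limit` characters.

    Discord rejects messages over 2000 characters, so longer content is
    sent as consecutive chunks. An empty message still yields one chunk
    so the caller always has something to send.
    """
    return [content[i : i + limit] for i in range(0, len(content), limit)] or [""]
```

A 4500-character message would come back as three chunks of 2000, 2000, and 500 characters, matching the send loop's behavior of posting each chunk in order and keeping the last message's ID.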


@@ -2,7 +2,7 @@ from typing import Any
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.json import json
from backend.util.json import loads
class StepThroughItemsBlock(Block):
@@ -68,7 +68,7 @@ class StepThroughItemsBlock(Block):
raise ValueError(
f"Input too large: {len(data)} bytes > {MAX_ITEM_SIZE} bytes"
)
items = json.loads(data)
items = loads(data)
else:
items = data


@@ -1,5 +1,8 @@
from typing import List
from urllib.parse import quote
from typing_extensions import TypedDict
from backend.blocks.jina._auth import (
JinaCredentials,
JinaCredentialsField,
@@ -10,6 +13,12 @@ from backend.data.model import SchemaField
from backend.util.request import Requests
class Reference(TypedDict):
url: str
keyQuote: str
isSupportive: bool
class FactCheckerBlock(Block):
class Input(BlockSchema):
statement: str = SchemaField(
@@ -23,6 +32,10 @@ class FactCheckerBlock(Block):
)
result: bool = SchemaField(description="The result of the factuality check")
reason: str = SchemaField(description="The reason for the factuality result")
references: List[Reference] = SchemaField(
description="List of references supporting or contradicting the statement",
default=[],
)
error: str = SchemaField(description="Error message if the check fails")
def __init__(self):
@@ -53,5 +66,11 @@ class FactCheckerBlock(Block):
yield "factuality", data["factuality"]
yield "result", data["result"]
yield "reason", data["reason"]
# Yield references if present in the response
if "references" in data:
yield "references", data["references"]
else:
yield "references", []
else:
raise RuntimeError(f"Expected 'data' key not found in response: {data}")


@@ -37,5 +37,5 @@ class Project(BaseModel):
name: str
description: str
priority: int
progress: int
content: str
progress: float
content: str | None


@@ -102,6 +102,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
CLAUDE_4_OPUS = "claude-opus-4-20250514"
CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
CLAUDE_3_5_HAIKU = "claude-3-5-haiku-latest"
@@ -217,6 +218,9 @@ MODEL_METADATA = {
LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
"anthropic", 200000, 64000
), # claude-sonnet-4-5-20250929
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
"anthropic", 200000, 64000
), # claude-haiku-4-5-20251001
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
"anthropic", 200000, 64000
), # claude-3-7-sonnet-20250219


@@ -0,0 +1,226 @@
# flake8: noqa: E501
import logging
from enum import Enum
from typing import Any, Literal
import openai
from pydantic import SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.logging import TruncatedLogger
logger = TruncatedLogger(logging.getLogger(__name__), "[Perplexity-Block]")
class PerplexityModel(str, Enum):
"""Perplexity sonar models available via OpenRouter"""
SONAR = "perplexity/sonar"
SONAR_PRO = "perplexity/sonar-pro"
SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
PerplexityCredentials = CredentialsMetaInput[
Literal[ProviderName.OPEN_ROUTER], Literal["api_key"]
]
TEST_CREDENTIALS = APIKeyCredentials(
id="test-perplexity-creds",
provider="open_router",
api_key=SecretStr("mock-openrouter-api-key"),
title="Mock OpenRouter API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
def PerplexityCredentialsField() -> PerplexityCredentials:
return CredentialsField(
description="OpenRouter API key for accessing Perplexity models.",
)
class PerplexityBlock(Block):
class Input(BlockSchema):
prompt: str = SchemaField(
description="The query to send to the Perplexity model.",
placeholder="Enter your query here...",
)
model: PerplexityModel = SchemaField(
title="Perplexity Model",
default=PerplexityModel.SONAR,
description="The Perplexity sonar model to use.",
advanced=False,
)
credentials: PerplexityCredentials = PerplexityCredentialsField()
system_prompt: str = SchemaField(
title="System Prompt",
default="",
description="Optional system prompt to provide context to the model.",
advanced=True,
)
max_tokens: int | None = SchemaField(
advanced=True,
default=None,
description="The maximum number of tokens to generate.",
)
class Output(BlockSchema):
response: str = SchemaField(
description="The response from the Perplexity model."
)
annotations: list[dict[str, Any]] = SchemaField(
description="List of URL citations and annotations from the response."
)
error: str = SchemaField(description="Error message if the API call failed.")
def __init__(self):
super().__init__(
id="c8a5f2e9-8b3d-4a7e-9f6c-1d5e3c9b7a4f",
description="Query Perplexity's sonar models with real-time web search capabilities and receive annotated responses with source citations.",
categories={BlockCategory.AI, BlockCategory.SEARCH},
input_schema=PerplexityBlock.Input,
output_schema=PerplexityBlock.Output,
test_input={
"prompt": "What is the weather today?",
"model": PerplexityModel.SONAR,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("response", "The weather varies by location..."),
("annotations", list),
],
test_mock={
"call_perplexity": lambda *args, **kwargs: {
"response": "The weather varies by location...",
"annotations": [
{
"type": "url_citation",
"url_citation": {
"title": "weather.com",
"url": "https://weather.com",
},
}
],
}
},
)
self.execution_stats = NodeExecutionStats()
async def call_perplexity(
self,
credentials: APIKeyCredentials,
model: PerplexityModel,
prompt: str,
system_prompt: str = "",
max_tokens: int | None = None,
) -> dict[str, Any]:
"""Call Perplexity via OpenRouter and extract annotations."""
client = openai.AsyncOpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=credentials.api_key.get_secret_value(),
)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
try:
response = await client.chat.completions.create(
extra_headers={
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=model.value,
messages=messages,
max_tokens=max_tokens,
)
if not response.choices:
raise ValueError("No response from Perplexity via OpenRouter.")
# Extract the response content
response_content = response.choices[0].message.content or ""
# Extract annotations if present in the message
annotations = []
if hasattr(response.choices[0].message, "annotations"):
# If annotations are directly available
annotations = response.choices[0].message.annotations
else:
# Check if there's a raw response with annotations
raw = getattr(response.choices[0].message, "_raw_response", None)
if isinstance(raw, dict) and "annotations" in raw:
annotations = raw["annotations"]
if not annotations and hasattr(response, "model_extra"):
# Check model_extra for annotations
model_extra = response.model_extra
if isinstance(model_extra, dict):
# Check in choices
if "choices" in model_extra and len(model_extra["choices"]) > 0:
choice = model_extra["choices"][0]
if "message" in choice and "annotations" in choice["message"]:
annotations = choice["message"]["annotations"]
# Also check the raw response object for annotations
if not annotations:
raw = getattr(response, "_raw_response", None)
if isinstance(raw, dict):
# Check various possible locations for annotations
if "annotations" in raw:
annotations = raw["annotations"]
elif "choices" in raw and len(raw["choices"]) > 0:
choice = raw["choices"][0]
if "message" in choice and "annotations" in choice["message"]:
annotations = choice["message"]["annotations"]
# Update execution stats
if response.usage:
self.execution_stats.input_token_count = response.usage.prompt_tokens
self.execution_stats.output_token_count = (
response.usage.completion_tokens
)
return {"response": response_content, "annotations": annotations or []}
except Exception as e:
logger.error(f"Error calling Perplexity: {e}")
raise
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
logger.debug(f"Running Perplexity block with model: {input_data.model}")
try:
result = await self.call_perplexity(
credentials=credentials,
model=input_data.model,
prompt=input_data.prompt,
system_prompt=input_data.system_prompt,
max_tokens=input_data.max_tokens,
)
yield "response", result["response"]
yield "annotations", result["annotations"]
except Exception as e:
error_msg = f"Error calling Perplexity: {str(e)}"
logger.error(error_msg)
yield "error", error_msg


@@ -69,6 +69,7 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.CLAUDE_4_1_OPUS: 21,
LlmModel.CLAUDE_4_OPUS: 21,
LlmModel.CLAUDE_4_SONNET: 5,
LlmModel.CLAUDE_4_5_HAIKU: 4,
LlmModel.CLAUDE_4_5_SONNET: 9,
LlmModel.CLAUDE_3_7_SONNET: 5,
LlmModel.CLAUDE_3_5_SONNET: 4,


@@ -5,7 +5,6 @@ from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, cast
import stripe
from prisma import Json
from prisma.enums import (
CreditRefundRequestStatus,
CreditTransactionType,
@@ -13,16 +12,12 @@ from prisma.enums import (
OnboardingStep,
)
from prisma.errors import UniqueViolationError
from prisma.models import CreditRefundRequest, CreditTransaction, User
from prisma.types import (
CreditRefundRequestCreateInput,
CreditTransactionCreateInput,
CreditTransactionWhereInput,
)
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
from prisma.types import CreditRefundRequestCreateInput, CreditTransactionWhereInput
from pydantic import BaseModel
from backend.data import db
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.db import query_raw_with_schema
from backend.data.includes import MAX_CREDIT_REFUND_REQUESTS_FETCH
from backend.data.model import (
AutoTopUpConfig,
@@ -36,7 +31,8 @@ from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.notifications.notifications import queue_notification_async
from backend.server.v2.admin.model import UserHistoryResponse
from backend.util.exceptions import InsufficientBalanceError
from backend.util.json import SafeJson
from backend.util.feature_flag import Flag, is_feature_enabled
from backend.util.json import SafeJson, dumps
from backend.util.models import Pagination
from backend.util.retry import func_retry
from backend.util.settings import Settings
@@ -49,6 +45,10 @@ stripe.api_key = settings.secrets.stripe_api_key
logger = logging.getLogger(__name__)
base_url = settings.config.frontend_base_url or settings.config.platform_base_url
# Constants for test compatibility
POSTGRES_INT_MAX = 2147483647
POSTGRES_INT_MIN = -2147483648
class UsageTransactionMetadata(BaseModel):
graph_exec_id: str | None = None
@@ -139,14 +139,20 @@ class UserCreditBase(ABC):
pass
@abstractmethod
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
async def onboarding_reward(
self, user_id: str, credits: int, step: OnboardingStep
) -> bool:
"""
Reward the user with credits for completing an onboarding step.
Won't reward if the user has already received credits for the step.
Args:
user_id (str): The user ID.
credits (int): The amount to reward.
step (OnboardingStep): The onboarding step.
Returns:
bool: True if rewarded, False if already rewarded.
"""
pass
@@ -236,6 +242,12 @@ class UserCreditBase(ABC):
"""
Returns the current balance of the user & the latest balance snapshot time.
"""
# Check UserBalance first for efficiency and consistency
user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
if user_balance:
return user_balance.balance, user_balance.updatedAt
# Fallback to transaction history computation if UserBalance doesn't exist
top_time = self.time_now()
snapshot = await CreditTransaction.prisma().find_first(
where={
@@ -250,72 +262,86 @@ class UserCreditBase(ABC):
snapshot_balance = snapshot.runningBalance or 0 if snapshot else 0
snapshot_time = snapshot.createdAt if snapshot else datetime_min
# Get transactions after the snapshot, this should not exist, but just in case.
transactions = await CreditTransaction.prisma().group_by(
by=["userId"],
sum={"amount": True},
max={"createdAt": True},
where={
"userId": user_id,
"createdAt": {
"gt": snapshot_time,
"lte": top_time,
},
"isActive": True,
},
)
transaction_balance = (
int(transactions[0].get("_sum", {}).get("amount", 0) + snapshot_balance)
if transactions
else snapshot_balance
)
transaction_time = (
datetime.fromisoformat(
str(transactions[0].get("_max", {}).get("createdAt", datetime_min))
)
if transactions
else snapshot_time
)
return transaction_balance, transaction_time
return snapshot_balance, snapshot_time
@func_retry
async def _enable_transaction(
self,
transaction_key: str,
user_id: str,
metadata: Json,
metadata: SafeJson,
new_transaction_key: str | None = None,
):
transaction = await CreditTransaction.prisma().find_first_or_raise(
where={"transactionKey": transaction_key, "userId": user_id}
# First check if transaction exists and is inactive (safety check)
transaction = await CreditTransaction.prisma().find_first(
where={
"transactionKey": transaction_key,
"userId": user_id,
"isActive": False,
}
)
if transaction.isActive:
return
if not transaction:
# Transaction doesn't exist or is already active, return early
return None
async with db.locked_transaction(f"usr_trx_{user_id}"):
transaction = await CreditTransaction.prisma().find_first_or_raise(
where={"transactionKey": transaction_key, "userId": user_id}
# Atomic operation to enable transaction and update user balance using UserBalance
result = await query_raw_with_schema(
"""
WITH user_balance_lock AS (
SELECT
$2::text as userId,
COALESCE(
(SELECT balance FROM {schema_prefix}"UserBalance" WHERE "userId" = $2 FOR UPDATE),
-- Fallback: compute balance from transaction history if UserBalance doesn't exist
(SELECT COALESCE(ct."runningBalance", 0)
FROM {schema_prefix}"CreditTransaction" ct
WHERE ct."userId" = $2
AND ct."isActive" = true
AND ct."runningBalance" IS NOT NULL
ORDER BY ct."createdAt" DESC
LIMIT 1),
0
) as balance
),
transaction_check AS (
SELECT * FROM {schema_prefix}"CreditTransaction"
WHERE "transactionKey" = $1 AND "userId" = $2 AND "isActive" = false
),
balance_update AS (
INSERT INTO {schema_prefix}"UserBalance" ("userId", "balance", "updatedAt")
SELECT
$2::text,
user_balance_lock.balance + transaction_check.amount,
CURRENT_TIMESTAMP
FROM user_balance_lock, transaction_check
ON CONFLICT ("userId") DO UPDATE SET
"balance" = EXCLUDED."balance",
"updatedAt" = EXCLUDED."updatedAt"
RETURNING "balance", "updatedAt"
),
transaction_update AS (
UPDATE {schema_prefix}"CreditTransaction"
SET "transactionKey" = COALESCE($4, $1),
"isActive" = true,
"runningBalance" = balance_update.balance,
"createdAt" = balance_update."updatedAt",
"metadata" = $3::jsonb
FROM balance_update, transaction_check
WHERE {schema_prefix}"CreditTransaction"."transactionKey" = transaction_check."transactionKey"
AND {schema_prefix}"CreditTransaction"."userId" = transaction_check."userId"
RETURNING {schema_prefix}"CreditTransaction"."runningBalance"
)
if transaction.isActive:
return
SELECT "runningBalance" as balance FROM transaction_update;
""",
transaction_key, # $1
user_id, # $2
dumps(metadata.data), # $3 - use pre-serialized JSON string for JSONB
new_transaction_key, # $4
)
user_balance, _ = await self._get_credits(user_id)
await CreditTransaction.prisma().update(
where={
"creditTransactionIdentifier": {
"transactionKey": transaction_key,
"userId": user_id,
}
},
data={
"transactionKey": new_transaction_key or transaction_key,
"isActive": True,
"runningBalance": user_balance + transaction.amount,
"createdAt": self.time_now(),
"metadata": metadata,
},
)
if result:
# UserBalance is already updated by the CTE
return result[0]["balance"]
async def _add_transaction(
self,
@@ -326,12 +352,54 @@ class UserCreditBase(ABC):
transaction_key: str | None = None,
ceiling_balance: int | None = None,
fail_insufficient_credits: bool = True,
metadata: Json = SafeJson({}),
metadata: SafeJson = SafeJson({}),
) -> tuple[int, str]:
"""
Add a new transaction for the user.
This is the only method that should be used to add a new transaction.
ATOMIC OPERATION DESIGN DECISION:
================================
This method uses PostgreSQL row-level locking (FOR UPDATE) for atomic credit operations.
After extensive analysis of concurrency patterns and correctness requirements, we determined
that the FOR UPDATE approach is necessary despite the latency overhead.
WHY FOR UPDATE LOCKING IS REQUIRED:
----------------------------------
1. **Data Consistency**: Credit operations must be ACID-compliant. The balance check,
calculation, and update must be atomic to prevent race conditions where:
- Multiple spend operations could exceed available balance
- Lost update problems could occur with concurrent top-ups
- Refunds could create negative balances incorrectly
2. **Serializability**: FOR UPDATE ensures operations are serialized at the database level,
guaranteeing that each transaction sees a consistent view of the balance before applying changes.
3. **Correctness Over Performance**: Financial operations require absolute correctness.
The ~10-50ms latency increase from row locking is acceptable for the guarantee that
no user will ever have an incorrect balance due to race conditions.
4. **PostgreSQL Optimization**: Modern PostgreSQL versions optimize row locks efficiently.
The performance cost is minimal compared to the complexity and risk of lock-free approaches.
ALTERNATIVES CONSIDERED AND REJECTED:
------------------------------------
- **Optimistic Concurrency**: Using version numbers or timestamps would require complex
retry logic and could still fail under high contention scenarios.
- **Application-Level Locking**: Redis locks or similar would add network overhead and
single points of failure while being less reliable than database locks.
- **Event Sourcing**: Would require complete architectural changes and eventual consistency
models that don't fit our real-time balance requirements.
PERFORMANCE CHARACTERISTICS:
---------------------------
- Single user operations: 10-50ms latency (acceptable for financial operations)
- Concurrent operations on same user: Serialized (prevents data corruption)
- Concurrent operations on different users: Fully parallel (no blocking)
This design prioritizes correctness and data integrity over raw performance,
which is the appropriate choice for a credit/payment system.
Args:
user_id (str): The user ID.
amount (int): The amount of credits to add.
@@ -345,40 +413,142 @@ class UserCreditBase(ABC):
Returns:
tuple[int, str]: The new balance & the transaction key.
"""
async with db.locked_transaction(f"usr_trx_{user_id}"):
# Get latest balance snapshot
user_balance, _ = await self._get_credits(user_id)
if ceiling_balance and amount > 0 and user_balance >= ceiling_balance:
# Quick validation for ceiling balance to avoid unnecessary database operations
if ceiling_balance and amount > 0:
current_balance, _ = await self._get_credits(user_id)
if current_balance >= ceiling_balance:
raise ValueError(
f"You already have enough balance of ${user_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
f"You already have enough balance of ${current_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
)
if amount < 0 and user_balance + amount < 0:
if fail_insufficient_credits:
raise InsufficientBalanceError(
message=f"Insufficient balance of ${user_balance/100}, where this will cost ${abs(amount)/100}",
user_id=user_id,
balance=user_balance,
amount=amount,
# Single unified atomic operation for all transaction types using UserBalance
try:
result = await query_raw_with_schema(
"""
WITH user_balance_lock AS (
SELECT
$1::text as userId,
-- CRITICAL: FOR UPDATE lock prevents concurrent modifications to the same user's balance
-- This ensures atomic read-modify-write operations and prevents race conditions
COALESCE(
(SELECT balance FROM {schema_prefix}"UserBalance" WHERE "userId" = $1 FOR UPDATE),
-- Fallback: compute balance from transaction history if UserBalance doesn't exist
(SELECT COALESCE(ct."runningBalance", 0)
FROM {schema_prefix}"CreditTransaction" ct
WHERE ct."userId" = $1
AND ct."isActive" = true
AND ct."runningBalance" IS NOT NULL
ORDER BY ct."createdAt" DESC
LIMIT 1),
0
) as balance
),
balance_update AS (
INSERT INTO {schema_prefix}"UserBalance" ("userId", "balance", "updatedAt")
SELECT
$1::text,
CASE
-- For inactive transactions: Don't update balance
WHEN $5::boolean = false THEN user_balance_lock.balance
-- For ceiling balance (amount > 0): Apply ceiling
WHEN $2 > 0 AND $7::int IS NOT NULL AND user_balance_lock.balance > $7::int - $2 THEN $7::int
-- For regular operations: Apply with overflow/underflow protection
WHEN user_balance_lock.balance + $2 > $6::int THEN $6::int
WHEN user_balance_lock.balance + $2 < $10::int THEN $10::int
ELSE user_balance_lock.balance + $2
END,
CURRENT_TIMESTAMP
FROM user_balance_lock
WHERE (
$5::boolean = false OR -- Allow inactive transactions
$2 >= 0 OR -- Allow positive amounts (top-ups, grants)
$8::boolean = false OR -- Allow when insufficient balance check is disabled
user_balance_lock.balance + $2 >= 0 -- Allow spending only when sufficient balance
)
ON CONFLICT ("userId") DO UPDATE SET
"balance" = EXCLUDED."balance",
"updatedAt" = EXCLUDED."updatedAt"
RETURNING "balance", "updatedAt"
),
transaction_insert AS (
INSERT INTO {schema_prefix}"CreditTransaction" (
"userId", "amount", "type", "runningBalance",
"metadata", "isActive", "createdAt", "transactionKey"
)
SELECT
$1::text,
$2::int,
$3::text::{schema_prefix}"CreditTransactionType",
CASE
-- For inactive transactions: Set runningBalance to original balance (don't apply the change yet)
WHEN $5::boolean = false THEN user_balance_lock.balance
ELSE COALESCE(balance_update.balance, user_balance_lock.balance)
END,
$4::jsonb,
$5::boolean,
COALESCE(balance_update."updatedAt", CURRENT_TIMESTAMP),
COALESCE($9, gen_random_uuid()::text)
FROM user_balance_lock
LEFT JOIN balance_update ON true
WHERE (
$5::boolean = false OR -- Allow inactive transactions
$2 >= 0 OR -- Allow positive amounts (top-ups, grants)
$8::boolean = false OR -- Allow when insufficient balance check is disabled
user_balance_lock.balance + $2 >= 0 -- Allow spending only when sufficient balance
)
RETURNING "runningBalance", "transactionKey"
)
SELECT "runningBalance" as balance, "transactionKey" FROM transaction_insert;
""",
user_id, # $1
amount, # $2
transaction_type.value, # $3
dumps(metadata.data), # $4 - use pre-serialized JSON string for JSONB
is_active, # $5
POSTGRES_INT_MAX, # $6 - overflow protection
ceiling_balance, # $7 - ceiling balance (nullable)
fail_insufficient_credits, # $8 - check balance for spending
transaction_key, # $9 - transaction key (nullable)
POSTGRES_INT_MIN, # $10 - underflow protection
)
except Exception as e:
# Convert raw SQL unique constraint violations to UniqueViolationError
# for consistent exception handling throughout the codebase
error_str = str(e).lower()
if (
"already exists" in error_str
or "duplicate key" in error_str
or "unique constraint" in error_str
):
# Re-raise as UniqueViolationError with the minimal data structure
# its constructor expects, for consistent handling upstream
raise UniqueViolationError({"error": str(e), "user_facing_error": {}})
# For any other error, re-raise as-is
raise
amount = min(-user_balance, 0)
if result:
new_balance, tx_key = result[0]["balance"], result[0]["transactionKey"]
# UserBalance is already updated by the CTE
return new_balance, tx_key
# Create the transaction
transaction_data: CreditTransactionCreateInput = {
"userId": user_id,
"amount": amount,
"runningBalance": user_balance + amount,
"type": transaction_type,
"metadata": metadata,
"isActive": is_active,
"createdAt": self.time_now(),
}
if transaction_key:
transaction_data["transactionKey"] = transaction_key
tx = await CreditTransaction.prisma().create(data=transaction_data)
return user_balance + amount, tx.transactionKey
# If no result, either user doesn't exist or insufficient balance
user = await User.prisma().find_unique(where={"id": user_id})
if not user:
raise ValueError(f"User {user_id} not found")
# Must be insufficient balance for spending operation
if amount < 0 and fail_insufficient_credits:
current_balance, _ = await self._get_credits(user_id)
raise InsufficientBalanceError(
message=f"Insufficient balance of ${current_balance/100}, where this will cost ${abs(amount)/100}",
user_id=user_id,
balance=current_balance,
amount=amount,
)
# Unexpected case
raise ValueError(f"Transaction failed for user {user_id}, amount {amount}")
class UserCredit(UserCreditBase):
@@ -453,9 +623,10 @@ class UserCredit(UserCreditBase):
{"reason": f"Reward for completing {step.value} onboarding step."}
),
)
return True
except UniqueViolationError:
# Already rewarded for this step
pass
# User already received this reward
return False
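The idempotency pattern here relies on a deterministic transaction key (`REWARD-{user_id}-{step}`, as the concurrency tests below verify) hitting a unique constraint on retry. A minimal in-memory sketch of the same pattern, with a set standing in for the database's unique index (all names hypothetical):

```python
class DuplicateKeyError(Exception):
    """Stand-in for the database's unique-constraint violation."""


_granted_keys: set[str] = set()


def grant_once(user_id: str, step: str, amount: int) -> bool:
    # Same key on every retry of this step, so duplicates collide.
    key = f"REWARD-{user_id}-{step}"
    if key in _granted_keys:
        raise DuplicateKeyError(key)
    _granted_keys.add(key)
    return True


def onboarding_reward(user_id: str, step: str, amount: int) -> bool:
    try:
        return grant_once(user_id, step, amount)
    except DuplicateKeyError:
        # User already received this reward
        return False
```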
async def top_up_refund(
self, user_id: str, transaction_key: str, metadata: dict[str, str]
@@ -644,7 +815,7 @@ class UserCredit(UserCreditBase):
):
# init metadata, without sharing it with the world
metadata = metadata or {}
if not metadata["reason"]:
if not metadata.get("reason"):
match top_up_type:
case TopUpType.MANUAL:
metadata["reason"] = {"reason": f"Top up credits for {user_id}"}
@@ -974,8 +1145,8 @@ class DisabledUserCredit(UserCreditBase):
async def top_up_credits(self, *args, **kwargs):
pass
async def onboarding_reward(self, *args, **kwargs):
pass
async def onboarding_reward(self, *args, **kwargs) -> bool:
return True
async def top_up_intent(self, *args, **kwargs) -> str:
return ""
@@ -993,14 +1164,31 @@ class DisabledUserCredit(UserCreditBase):
pass
def get_user_credit_model() -> UserCreditBase:
async def get_user_credit_model(user_id: str) -> UserCreditBase:
"""
Get the credit model for a user, considering LaunchDarkly flags.
Args:
user_id (str): The user ID to check flags for.
Returns:
UserCreditBase: The appropriate credit model for the user
"""
if not settings.config.enable_credit:
return DisabledUserCredit()
if settings.config.enable_beta_monthly_credit:
return BetaUserCredit(settings.config.num_user_credits_refill)
# Check LaunchDarkly flag for payment pilot users
# Default to False (beta monthly credit behavior) to maintain current behavior
is_payment_enabled = await is_feature_enabled(
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
)
return UserCredit()
if is_payment_enabled:
# Payment enabled users get UserCredit (no monthly refills, enable payments)
return UserCredit()
else:
# Default behavior: users get beta monthly credits
return BetaUserCredit(settings.config.num_user_credits_refill)
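The selection logic above reduces to a small decision table. A pure-function sketch (the flag lookup is replaced by a plain boolean argument; the class names match the source, but this helper itself is hypothetical):

```python
def select_credit_model(enable_credit: bool, payment_flag_enabled: bool) -> str:
    """Return the name of the credit model a user would get."""
    if not enable_credit:
        return "DisabledUserCredit"
    if payment_flag_enabled:
        # Payment pilot users: no monthly refills, payments enabled.
        return "UserCredit"
    # Default behavior: beta monthly credits.
    return "BetaUserCredit"
```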
def get_block_costs() -> dict[str, list["BlockCost"]]:
@@ -1090,7 +1278,8 @@ async def admin_get_user_history(
)
reason = metadata.get("reason", "No reason provided")
balance, last_update = await get_user_credit_model()._get_credits(tx.userId)
user_credit_model = await get_user_credit_model(tx.userId)
balance, _ = await user_credit_model._get_credits(tx.userId)
history.append(
UserTransaction(

@@ -0,0 +1,172 @@
"""
Test ceiling balance functionality to ensure auto top-up limits work correctly.
This test was added to cover a previously untested code path that could lead to
incorrect balance capping behavior.
"""
from uuid import uuid4
import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from backend.data.credit import UserCredit
from backend.util.json import SafeJson
from backend.util.test import SpinTestServer
async def create_test_user(user_id: str) -> None:
"""Create a test user for ceiling tests."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
)
except UniqueViolationError:
# User already exists, continue
pass
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
)
async def cleanup_test_user(user_id: str) -> None:
"""Clean up test user and their transactions."""
try:
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await User.prisma().delete_many(where={"id": user_id})
except Exception as e:
# Log cleanup failures but don't fail the test
print(f"Warning: Failed to cleanup test user {user_id}: {e}")
@pytest.mark.asyncio(loop_scope="session")
async def test_ceiling_balance_rejects_when_above_threshold(server: SpinTestServer):
"""Test that ceiling balance correctly rejects top-ups when balance is above threshold."""
credit_system = UserCredit()
user_id = f"ceiling-test-{uuid4()}"
await create_test_user(user_id)
try:
# Give user balance of 1000 ($10) using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=1000,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "initial_balance"}),
)
current_balance = await credit_system.get_credits(user_id)
assert current_balance == 1000
# Try to add 200 more with ceiling of 800 (should reject since 1000 > 800)
with pytest.raises(ValueError, match="You already have enough balance"):
await credit_system._add_transaction(
user_id=user_id,
amount=200,
transaction_type=CreditTransactionType.TOP_UP,
ceiling_balance=800, # Ceiling lower than current balance
)
# Balance should remain unchanged
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 1000, f"Balance should remain 1000, got {final_balance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_ceiling_balance_clamps_when_would_exceed(server: SpinTestServer):
"""Test that ceiling balance correctly clamps amounts that would exceed the ceiling."""
credit_system = UserCredit()
user_id = f"ceiling-clamp-test-{uuid4()}"
await create_test_user(user_id)
try:
# Give user balance of 500 ($5) using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=500,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "initial_balance"}),
)
# Add 800 more with ceiling of 1000 (should clamp to 1000, not reach 1300)
final_balance, _ = await credit_system._add_transaction(
user_id=user_id,
amount=800,
transaction_type=CreditTransactionType.TOP_UP,
ceiling_balance=1000, # Ceiling should clamp 500 + 800 = 1300 to 1000
)
# Balance should be clamped to ceiling
assert (
final_balance == 1000
), f"Balance should be clamped to 1000, got {final_balance}"
# Verify with get_credits too
stored_balance = await credit_system.get_credits(user_id)
assert (
stored_balance == 1000
), f"Stored balance should be 1000, got {stored_balance}"
# Verify transaction shows the clamped amount
transactions = await CreditTransaction.prisma().find_many(
where={"userId": user_id, "type": CreditTransactionType.TOP_UP},
order={"createdAt": "desc"},
)
# Should have 2 transactions: 500 + (500 to reach ceiling of 1000)
assert len(transactions) == 2
# The second transaction should show it only added 500, not 800
second_tx = transactions[0] # Most recent
assert second_tx.runningBalance == 1000
# The actual amount recorded could be 800 (what was requested) but balance was clamped
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_ceiling_balance_allows_when_under_threshold(server: SpinTestServer):
"""Test that ceiling balance allows top-ups when balance is under threshold."""
credit_system = UserCredit()
user_id = f"ceiling-under-test-{uuid4()}"
await create_test_user(user_id)
try:
# Give user balance of 300 ($3) using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=300,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "initial_balance"}),
)
# Add 200 more with ceiling of 1000 (should succeed: 300 + 200 = 500 < 1000)
final_balance, _ = await credit_system._add_transaction(
user_id=user_id,
amount=200,
transaction_type=CreditTransactionType.TOP_UP,
ceiling_balance=1000,
)
# Balance should be exactly 500
assert final_balance == 500, f"Balance should be 500, got {final_balance}"
# Verify with get_credits too
stored_balance = await credit_system.get_credits(user_id)
assert (
stored_balance == 500
), f"Stored balance should be 500, got {stored_balance}"
finally:
await cleanup_test_user(user_id)

@@ -0,0 +1,737 @@
"""
Concurrency and atomicity tests for the credit system.
These tests ensure the credit system handles high-concurrency scenarios correctly
without race conditions, deadlocks, or inconsistent state.
"""
import asyncio
import random
from uuid import uuid4
import prisma.enums
import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit
from backend.util.exceptions import InsufficientBalanceError
from backend.util.json import SafeJson
from backend.util.test import SpinTestServer
# Test with both UserCredit and BetaUserCredit if needed
credit_system = UserCredit()
async def create_test_user(user_id: str) -> None:
"""Create a test user with initial balance."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
)
except UniqueViolationError:
# User already exists, continue
pass
# Ensure UserBalance record exists
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
)
async def cleanup_test_user(user_id: str) -> None:
"""Clean up test user and their transactions."""
try:
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await UserBalance.prisma().delete_many(where={"userId": user_id})
await User.prisma().delete_many(where={"id": user_id})
except Exception as e:
# Log cleanup failures but don't fail the test
print(f"Warning: Failed to cleanup test user {user_id}: {e}")
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_spends_same_user(server: SpinTestServer):
"""Test multiple concurrent spends from the same user don't cause race conditions."""
user_id = f"concurrent-test-{uuid4()}"
await create_test_user(user_id)
try:
# Give user initial balance using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=1000,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "initial_balance"}),
)
# Try to spend 10 x $1 concurrently
async def spend_one_dollar(idx: int):
try:
return await credit_system.spend_credits(
user_id,
100, # $1
UsageTransactionMetadata(
graph_exec_id=f"concurrent-{idx}",
reason=f"Concurrent spend {idx}",
),
)
except InsufficientBalanceError:
return None
# Run 10 concurrent spends
results = await asyncio.gather(
*[spend_one_dollar(i) for i in range(10)], return_exceptions=True
)
# Count successful spends
successful = [
r for r in results if r is not None and not isinstance(r, Exception)
]
failed = [r for r in results if isinstance(r, InsufficientBalanceError)]
# All 10 should succeed since we have exactly $10
assert len(successful) == 10, f"Expected 10 successful, got {len(successful)}"
assert len(failed) == 0, f"Expected 0 failures, got {len(failed)}"
# Final balance should be exactly 0
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 0, f"Expected balance 0, got {final_balance}"
# Verify transaction history is consistent
transactions = await CreditTransaction.prisma().find_many(
where={"userId": user_id, "type": prisma.enums.CreditTransactionType.USAGE}
)
assert (
len(transactions) == 10
), f"Expected 10 transactions, got {len(transactions)}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_spends_insufficient_balance(server: SpinTestServer):
"""Test that concurrent spends correctly enforce balance limits."""
user_id = f"insufficient-test-{uuid4()}"
await create_test_user(user_id)
try:
# Give user limited balance using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=500,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "limited_balance"}),
)
# Try to spend 10 x $1 concurrently (but only have $5)
async def spend_one_dollar(idx: int):
try:
return await credit_system.spend_credits(
user_id,
100, # $1
UsageTransactionMetadata(
graph_exec_id=f"insufficient-{idx}",
reason=f"Insufficient spend {idx}",
),
)
except InsufficientBalanceError:
return "FAILED"
# Run 10 concurrent spends
results = await asyncio.gather(
*[spend_one_dollar(i) for i in range(10)], return_exceptions=True
)
# Count successful vs failed
successful = [
r
for r in results
if r not in ["FAILED", None] and not isinstance(r, Exception)
]
failed = [r for r in results if r == "FAILED"]
# Exactly 5 should succeed, 5 should fail
assert len(successful) == 5, f"Expected 5 successful, got {len(successful)}"
assert len(failed) == 5, f"Expected 5 failures, got {len(failed)}"
# Final balance should be exactly 0
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 0, f"Expected balance 0, got {final_balance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_mixed_operations(server: SpinTestServer):
"""Test concurrent mix of spends, top-ups, and balance checks."""
user_id = f"mixed-test-{uuid4()}"
await create_test_user(user_id)
try:
# Initial balance using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=1000,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "initial_balance"}),
)
# Mix of operations
async def mixed_operations():
operations = []
# 5 spends of $1 each
for i in range(5):
operations.append(
credit_system.spend_credits(
user_id,
100,
UsageTransactionMetadata(reason=f"Mixed spend {i}"),
)
)
# 3 top-ups of $2 each using internal method
for i in range(3):
operations.append(
credit_system._add_transaction(
user_id=user_id,
amount=200,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": f"concurrent_topup_{i}"}),
)
)
# 10 balance checks
for i in range(10):
operations.append(credit_system.get_credits(user_id))
return await asyncio.gather(*operations, return_exceptions=True)
results = await mixed_operations()
# Check no exceptions occurred
exceptions = [
r
for r in results
if isinstance(r, Exception) and not isinstance(r, InsufficientBalanceError)
]
assert len(exceptions) == 0, f"Unexpected exceptions: {exceptions}"
# Final balance should be: 1000 - 500 + 600 = 1100
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 1100, f"Expected balance 1100, got {final_balance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_race_condition_exact_balance(server: SpinTestServer):
"""Test spending exact balance amount concurrently doesn't go negative."""
user_id = f"exact-balance-{uuid4()}"
await create_test_user(user_id)
try:
# Give exact amount using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=100,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "exact_amount"}),
)
# Try to spend $1 twice concurrently
async def spend_exact():
try:
return await credit_system.spend_credits(
user_id, 100, UsageTransactionMetadata(reason="Exact spend")
)
except InsufficientBalanceError:
return "FAILED"
# Both try to spend the full balance
result1, result2 = await asyncio.gather(spend_exact(), spend_exact())
# Exactly one should succeed
results = [result1, result2]
successful = [
r for r in results if r != "FAILED" and not isinstance(r, Exception)
]
failed = [r for r in results if r == "FAILED"]
assert len(successful) == 1, f"Expected 1 success, got {len(successful)}"
assert len(failed) == 1, f"Expected 1 failure, got {len(failed)}"
# Balance should be exactly 0, never negative
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 0, f"Expected balance 0, got {final_balance}"
finally:
await cleanup_test_user(user_id)
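The exact-balance race above depends on the `FOR UPDATE` row lock serializing the two spends so that only one sees a sufficient balance. A minimal in-memory sketch of that serialization, with an `asyncio.Lock` playing the role of the row lock (hypothetical, not part of the codebase):

```python
import asyncio


class MiniLedger:
    """Toy per-user ledger; the lock stands in for SELECT ... FOR UPDATE."""

    def __init__(self, balance: int) -> None:
        self.balance = balance
        self._lock = asyncio.Lock()

    async def spend(self, amount: int) -> bool:
        async with self._lock:  # concurrent spenders wait here, serially
            if self.balance < amount:
                return False  # insufficient balance, nothing deducted
            self.balance -= amount
            return True


async def double_spend() -> tuple[list[bool], int]:
    ledger = MiniLedger(100)
    results = await asyncio.gather(ledger.spend(100), ledger.spend(100))
    return list(results), ledger.balance
```

Exactly one spend succeeds and the balance never goes negative, which is the invariant the real test asserts against the database.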
@pytest.mark.asyncio(loop_scope="session")
async def test_onboarding_reward_idempotency(server: SpinTestServer):
"""Test that onboarding rewards are idempotent (can't be claimed twice)."""
user_id = f"onboarding-test-{uuid4()}"
await create_test_user(user_id)
try:
# Use WELCOME step which is defined in the OnboardingStep enum
# Try to claim same reward multiple times concurrently
async def claim_reward():
try:
result = await credit_system.onboarding_reward(
user_id, 500, prisma.enums.OnboardingStep.WELCOME
)
return "SUCCESS" if result else "DUPLICATE"
except Exception as e:
print(f"Claim reward failed: {e}")
return "FAILED"
# Try 5 concurrent claims of the same reward
results = await asyncio.gather(*[claim_reward() for _ in range(5)])
# Count results
success_count = results.count("SUCCESS")
failed_count = results.count("FAILED")
# At least one should succeed, others should be duplicates
assert success_count >= 1, f"At least one claim should succeed, got {results}"
assert failed_count == 0, f"No claims should fail, got {results}"
# Check balance - should only have 500, not 2500
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 500, f"Expected balance 500, got {final_balance}"
# Check only one transaction exists
transactions = await CreditTransaction.prisma().find_many(
where={
"userId": user_id,
"type": prisma.enums.CreditTransactionType.GRANT,
"transactionKey": f"REWARD-{user_id}-WELCOME",
}
)
assert (
len(transactions) == 1
), f"Expected 1 reward transaction, got {len(transactions)}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_integer_overflow_protection(server: SpinTestServer):
"""Test that integer overflow is prevented by clamping to POSTGRES_INT_MAX."""
user_id = f"overflow-test-{uuid4()}"
await create_test_user(user_id)
try:
# Try to add amount that would overflow
max_int = POSTGRES_INT_MAX
# First, set balance near max
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": max_int - 100},
"update": {"balance": max_int - 100},
},
)
# Try to add more than possible - should clamp to POSTGRES_INT_MAX
await credit_system._add_transaction(
user_id=user_id,
amount=200,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "overflow_protection"}),
)
# Balance should be clamped to max_int, not overflowed
final_balance = await credit_system.get_credits(user_id)
assert (
final_balance == max_int
), f"Balance should be clamped to {max_int}, got {final_balance}"
# Verify transaction was created with clamped amount
transactions = await CreditTransaction.prisma().find_many(
where={
"userId": user_id,
"type": prisma.enums.CreditTransactionType.TOP_UP,
},
order={"createdAt": "desc"},
)
assert len(transactions) > 0, "Transaction should be created"
assert (
transactions[0].runningBalance == max_int
), "Transaction should show clamped balance"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_high_concurrency_stress(server: SpinTestServer):
"""Stress test with many concurrent operations."""
user_id = f"stress-test-{uuid4()}"
await create_test_user(user_id)
try:
# Initial balance using internal method (bypasses Stripe)
initial_balance = 10000 # $100
await credit_system._add_transaction(
user_id=user_id,
amount=initial_balance,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "stress_test_balance"}),
)
# Run many concurrent operations
async def random_operation(idx: int):
operation = random.choice(["spend", "check"])
if operation == "spend":
amount = random.randint(1, 50) # $0.01 to $0.50
try:
return (
"spend",
amount,
await credit_system.spend_credits(
user_id,
amount,
UsageTransactionMetadata(reason=f"Stress {idx}"),
),
)
except InsufficientBalanceError:
return ("spend_failed", amount, None)
else:
balance = await credit_system.get_credits(user_id)
return ("check", 0, balance)
# Run 100 concurrent operations
results = await asyncio.gather(
*[random_operation(i) for i in range(100)], return_exceptions=True
)
# Calculate expected final balance
total_spent = sum(
r[1]
for r in results
if not isinstance(r, Exception) and isinstance(r, tuple) and r[0] == "spend"
)
expected_balance = initial_balance - total_spent
# Verify final balance
final_balance = await credit_system.get_credits(user_id)
assert (
final_balance == expected_balance
), f"Expected {expected_balance}, got {final_balance}"
assert final_balance >= 0, "Balance went negative!"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestServer):
"""Test multiple concurrent spends when there's sufficient balance for all."""
user_id = f"multi-spend-test-{uuid4()}"
await create_test_user(user_id)
try:
# Give user 150 balance ($1.50) using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=150,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "sufficient_balance"}),
)
# Track individual timing to see serialization
timings = {}
async def spend_with_detailed_timing(amount: int, label: str):
start = asyncio.get_event_loop().time()
try:
await credit_system.spend_credits(
user_id,
amount,
UsageTransactionMetadata(
graph_exec_id=f"concurrent-{label}",
reason=f"Concurrent spend {label}",
),
)
end = asyncio.get_event_loop().time()
timings[label] = {"start": start, "end": end, "duration": end - start}
return f"{label}-SUCCESS"
except Exception as e:
end = asyncio.get_event_loop().time()
timings[label] = {
"start": start,
"end": end,
"duration": end - start,
"error": str(e),
}
return f"{label}-FAILED: {e}"
# Run concurrent spends: 10, 20, 30 (total 60, well under 150)
overall_start = asyncio.get_event_loop().time()
results = await asyncio.gather(
spend_with_detailed_timing(10, "spend-10"),
spend_with_detailed_timing(20, "spend-20"),
spend_with_detailed_timing(30, "spend-30"),
return_exceptions=True,
)
overall_end = asyncio.get_event_loop().time()
print(f"Results: {results}")
print(f"Overall duration: {overall_end - overall_start:.4f}s")
# Analyze timing to detect serialization vs true concurrency
print("\nTiming analysis:")
for label, timing in timings.items():
print(
f" {label}: started at {timing['start']:.4f}, ended at {timing['end']:.4f}, duration {timing['duration']:.4f}s"
)
# Check if operations overlapped (true concurrency) or were serialized
sorted_timings = sorted(timings.items(), key=lambda x: x[1]["start"])
print("\nExecution order by start time:")
for i, (label, timing) in enumerate(sorted_timings):
print(f" {i+1}. {label}: {timing['start']:.4f} -> {timing['end']:.4f}")
# Check for overlap (true concurrency) vs serialization
overlaps = []
for i in range(len(sorted_timings) - 1):
current = sorted_timings[i]
next_op = sorted_timings[i + 1]
if current[1]["end"] > next_op[1]["start"]:
overlaps.append(f"{current[0]} overlaps with {next_op[0]}")
if overlaps:
print(f"✅ TRUE CONCURRENCY detected: {overlaps}")
else:
print("🔒 SERIALIZATION detected: No overlapping execution times")
# Check final balance
final_balance = await credit_system.get_credits(user_id)
print(f"Final balance: {final_balance}")
# Count successes/failures
successful = [r for r in results if "SUCCESS" in str(r)]
failed = [r for r in results if "FAILED" in str(r)]
print(f"Successful: {len(successful)}, Failed: {len(failed)}")
# All should succeed since 150 - (10 + 20 + 30) = 90 > 0
assert (
len(successful) == 3
), f"Expected all 3 to succeed, got {len(successful)} successes: {results}"
assert final_balance == 90, f"Expected balance 90, got {final_balance}"
# Check transaction timestamps to confirm database-level serialization
transactions = await CreditTransaction.prisma().find_many(
where={"userId": user_id, "type": prisma.enums.CreditTransactionType.USAGE},
order={"createdAt": "asc"},
)
print("\nDatabase transaction order (by createdAt):")
for i, tx in enumerate(transactions):
print(
f" {i+1}. Amount {tx.amount}, Running balance: {tx.runningBalance}, Created: {tx.createdAt}"
)
# Verify running balances are chronologically consistent (ordered by createdAt)
actual_balances = [
tx.runningBalance for tx in transactions if tx.runningBalance is not None
]
print(f"Running balances: {actual_balances}")
# The balances should be valid intermediate states regardless of execution order
# Starting balance: 150, spending 10+20+30=60, so final should be 90
# The intermediate balances depend on execution order but should all be valid
expected_possible_balances = {
# If order is 10, 20, 30: [140, 120, 90]
# If order is 10, 30, 20: [140, 110, 90]
# If order is 20, 10, 30: [130, 120, 90]
# If order is 20, 30, 10: [130, 100, 90]
# If order is 30, 10, 20: [120, 110, 90]
# If order is 30, 20, 10: [120, 100, 90]
90,
100,
110,
120,
130,
140, # All possible intermediate balances
}
# Verify all balances are valid intermediate states
for balance in actual_balances:
assert (
balance in expected_possible_balances
), f"Invalid balance {balance}, expected one of {expected_possible_balances}"
# Final balance should always be 90 (150 - 60). Transactions are atomic
# but can complete in any order; every recorded running balance must be a
# valid intermediate state between 90 (final) and 140 (after the first spend).
for balance in actual_balances:
assert (
90 <= balance <= 140
), f"Balance {balance} is outside valid range [90, 140]"
assert (
min(actual_balances) == 90
), f"Final balance should be 90, got {min(actual_balances)}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_prove_database_locking_behavior(server: SpinTestServer):
"""Definitively prove whether database locking causes waiting vs failures."""
user_id = f"locking-test-{uuid4()}"
await create_test_user(user_id)
try:
# Set balance to exact amount that can handle all spends using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=60, # Exactly 10+20+30
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "exact_amount_test"}),
)
async def spend_with_precise_timing(amount: int, label: str):
request_start = asyncio.get_event_loop().time()
db_operation_start = asyncio.get_event_loop().time()
try:
# Add a small delay to increase chance of true concurrency
await asyncio.sleep(0.001)
db_operation_start = asyncio.get_event_loop().time()
await credit_system.spend_credits(
user_id,
amount,
UsageTransactionMetadata(
graph_exec_id=f"locking-{label}",
reason=f"Locking test {label}",
),
)
db_operation_end = asyncio.get_event_loop().time()
return {
"label": label,
"status": "SUCCESS",
"request_start": request_start,
"db_start": db_operation_start,
"db_end": db_operation_end,
"db_duration": db_operation_end - db_operation_start,
}
except Exception as e:
db_operation_end = asyncio.get_event_loop().time()
return {
"label": label,
"status": "FAILED",
"error": str(e),
"request_start": request_start,
"db_start": db_operation_start,
"db_end": db_operation_end,
"db_duration": db_operation_end - db_operation_start,
}
# Launch all requests simultaneously
results = await asyncio.gather(
spend_with_precise_timing(10, "A"),
spend_with_precise_timing(20, "B"),
spend_with_precise_timing(30, "C"),
return_exceptions=True,
)
print("\n🔍 LOCKING BEHAVIOR ANALYSIS:")
print("=" * 50)
successful = [
r for r in results if isinstance(r, dict) and r.get("status") == "SUCCESS"
]
failed = [
r for r in results if isinstance(r, dict) and r.get("status") == "FAILED"
]
print(f"✅ Successful operations: {len(successful)}")
print(f"❌ Failed operations: {len(failed)}")
if len(failed) > 0:
print(
"\n🚫 CONCURRENT FAILURES - Some requests failed due to insufficient balance:"
)
for result in failed:
if isinstance(result, dict):
print(
f" {result['label']}: {result.get('error', 'Unknown error')}"
)
if len(successful) == 3:
print(
"\n🔒 SERIALIZATION CONFIRMED - All requests succeeded, indicating they were queued:"
)
# Sort by actual execution time to see order
dict_results = [r for r in results if isinstance(r, dict)]
sorted_results = sorted(dict_results, key=lambda x: x["db_start"])
for i, result in enumerate(sorted_results):
print(
f" {i+1}. {result['label']}: DB operation took {result['db_duration']:.4f}s"
)
# Check if any operations overlapped at the database level
print("\n⏱️ Database operation timeline:")
for result in sorted_results:
print(
f" {result['label']}: {result['db_start']:.4f} -> {result['db_end']:.4f}"
)
# Verify final state
final_balance = await credit_system.get_credits(user_id)
print(f"\n💰 Final balance: {final_balance}")
if len(successful) == 3:
assert (
final_balance == 0
), f"If all succeeded, balance should be 0, got {final_balance}"
print(
"✅ CONCLUSION: Database row locking causes requests to WAIT and execute serially"
)
else:
print(
"❌ CONCLUSION: Some requests failed, indicating different concurrency behavior"
)
finally:
await cleanup_test_user(user_id)
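The timeline analysis above can be reproduced in miniature without a database: an `asyncio.Lock` stands in for the row lock, and serialized operations show non-overlapping start/end timestamps. A standalone sketch (not part of this diff; all names here are illustrative):

```python
import asyncio
import time


async def locked_op(lock: asyncio.Lock, label: str, duration: float) -> dict:
    # Each operation waits for the shared lock, mimicking a DB row lock.
    async with lock:
        start = time.monotonic()
        await asyncio.sleep(duration)
        end = time.monotonic()
    return {"label": label, "start": start, "end": end}


async def main() -> list[dict]:
    lock = asyncio.Lock()
    results = await asyncio.gather(
        locked_op(lock, "A", 0.01),
        locked_op(lock, "B", 0.01),
        locked_op(lock, "C", 0.01),
    )
    # Serialized operations never overlap: each starts after the previous ends.
    ordered = sorted(results, key=lambda r: r["start"])
    for prev, nxt in zip(ordered, ordered[1:]):
        assert nxt["start"] >= prev["end"]
    return ordered


results = asyncio.run(main())
```

If the lock is removed, the three sleeps overlap and the non-overlap assertion fails, which is exactly the distinction the test's timeline printout is meant to surface.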


@@ -0,0 +1,277 @@
"""
Integration tests for credit system to catch SQL enum casting issues.
These tests run actual database operations to ensure SQL queries work correctly,
which would have caught the CreditTransactionType enum casting bug.
"""
import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, User, UserBalance
from backend.data.credit import (
AutoTopUpConfig,
BetaUserCredit,
UsageTransactionMetadata,
get_auto_top_up,
set_auto_top_up,
)
from backend.util.json import SafeJson
@pytest.fixture
async def cleanup_test_user():
"""Clean up test user data before and after tests."""
import uuid
user_id = str(uuid.uuid4()) # Use unique user ID for each test
# Create the user first
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"topUpConfig": SafeJson({}),
"timezone": "UTC",
}
)
except Exception:
# User might already exist, that's fine
pass
yield user_id
# Cleanup after test
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await UserBalance.prisma().delete_many(where={"userId": user_id})
# Clear auto-top-up config before deleting user
await User.prisma().update(
where={"id": user_id}, data={"topUpConfig": SafeJson({})}
)
await User.prisma().delete(where={"id": user_id})
@pytest.mark.asyncio(loop_scope="session")
async def test_credit_transaction_enum_casting_integration(cleanup_test_user):
"""
Integration test to verify CreditTransactionType enum casting works in SQL queries.
This test would have caught the enum casting bug where PostgreSQL expected
platform."CreditTransactionType" but got "CreditTransactionType".
"""
user_id = cleanup_test_user
credit_system = BetaUserCredit(1000)
# Test each transaction type to ensure enum casting works
test_cases = [
(CreditTransactionType.TOP_UP, 100, "Test top-up"),
(CreditTransactionType.USAGE, -50, "Test usage"),
(CreditTransactionType.GRANT, 200, "Test grant"),
(CreditTransactionType.REFUND, -25, "Test refund"),
(CreditTransactionType.CARD_CHECK, 0, "Test card check"),
]
for transaction_type, amount, reason in test_cases:
metadata = SafeJson({"reason": reason, "test": "enum_casting"})
# This call would fail with enum casting error before the fix
balance, tx_key = await credit_system._add_transaction(
user_id=user_id,
amount=amount,
transaction_type=transaction_type,
metadata=metadata,
is_active=True,
)
# Verify transaction was created with correct type
transaction = await CreditTransaction.prisma().find_first(
where={"userId": user_id, "transactionKey": tx_key}
)
assert transaction is not None
assert transaction.type == transaction_type
assert transaction.amount == amount
assert transaction.metadata is not None
# Verify metadata content
assert transaction.metadata["reason"] == reason
assert transaction.metadata["test"] == "enum_casting"
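The bug this test guards against is the missing schema qualifier on the enum cast in raw SQL. A hypothetical query-builder sketch (the schema and table names follow what the docstring implies; this is not the actual query from `backend.data.credit`):

```python
def insert_transaction_sql(schema: str = "platform") -> str:
    # The enum type lives inside the "platform" schema, so the cast must be
    # schema-qualified; a bare ::"CreditTransactionType" fails in PostgreSQL
    # when the type is not on the search path.
    return (
        f'INSERT INTO {schema}."CreditTransaction" ("userId", "amount", "type") '
        f'VALUES ($1, $2, $3::{schema}."CreditTransactionType")'
    )


sql = insert_transaction_sql()
```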
@pytest.mark.asyncio(loop_scope="session")
async def test_auto_top_up_integration(cleanup_test_user, monkeypatch):
"""
Integration test for auto-top-up functionality that triggers enum casting.
This tests the complete auto-top-up flow which involves SQL queries with
CreditTransactionType enums, ensuring enum casting works end-to-end.
"""
# Enable credits for this test
from backend.data.credit import settings
monkeypatch.setattr(settings.config, "enable_credit", True)
monkeypatch.setattr(settings.config, "enable_beta_monthly_credit", True)
monkeypatch.setattr(settings.config, "num_user_credits_refill", 1000)
user_id = cleanup_test_user
credit_system = BetaUserCredit(1000)
# First add some initial credits so we can test the configuration and subsequent behavior
balance, _ = await credit_system._add_transaction(
user_id=user_id,
amount=50, # Below threshold that we'll set
transaction_type=CreditTransactionType.GRANT,
metadata=SafeJson({"reason": "Initial credits before auto top-up config"}),
)
assert balance == 50
# Configure auto top-up with threshold above current balance
config = AutoTopUpConfig(threshold=100, amount=500)
await set_auto_top_up(user_id, config)
# Verify configuration was saved but no immediate top-up occurred
current_balance = await credit_system.get_credits(user_id)
assert current_balance == 50 # Balance should be unchanged
# Simulate spending credits that would trigger auto top-up
# This involves multiple SQL operations with enum casting
try:
metadata = UsageTransactionMetadata(reason="Test spend to trigger auto top-up")
await credit_system.spend_credits(user_id=user_id, cost=10, metadata=metadata)
# The auto top-up mechanism should have been triggered
# Verify the transaction types were handled correctly
transactions = await CreditTransaction.prisma().find_many(
where={"userId": user_id}, order={"createdAt": "desc"}
)
# Should have at least: GRANT (initial), USAGE (spend), and TOP_UP (auto top-up)
assert len(transactions) >= 3
# Verify different transaction types exist and enum casting worked
transaction_types = {t.type for t in transactions}
assert CreditTransactionType.GRANT in transaction_types
assert CreditTransactionType.USAGE in transaction_types
assert (
CreditTransactionType.TOP_UP in transaction_types
) # Auto top-up should have triggered
except Exception as e:
# If this fails with enum casting error, the test successfully caught the bug
if "CreditTransactionType" in str(e) and (
"cast" in str(e).lower() or "type" in str(e).lower()
):
pytest.fail(f"Enum casting error detected: {e}")
else:
# Re-raise other unexpected errors
raise
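The threshold rule the test exercises can be sketched as a pure function (an assumption about the behavior being asserted, not the actual auto-top-up implementation):

```python
def auto_top_up_amount(balance: int, threshold: int, top_up: int) -> int:
    """Credits to grant after a spend: top up only when the balance
    has fallen to or below the configured threshold."""
    return top_up if balance <= threshold else 0


assert auto_top_up_amount(40, 100, 500) == 500   # below threshold: top up
assert auto_top_up_amount(150, 100, 500) == 0    # above threshold: no-op
```

In the test above, spending 10 from a balance of 50 leaves 40, which is below the configured threshold of 100, so a TOP_UP transaction is expected to appear.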
@pytest.mark.asyncio(loop_scope="session")
async def test_enable_transaction_enum_casting_integration(cleanup_test_user):
"""
Integration test for _enable_transaction with enum casting.
Tests the scenario where inactive transactions are enabled, which also
involves SQL queries with CreditTransactionType enum casting.
"""
user_id = cleanup_test_user
credit_system = BetaUserCredit(1000)
# Create an inactive transaction
balance, tx_key = await credit_system._add_transaction(
user_id=user_id,
amount=100,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"reason": "Inactive transaction test"}),
is_active=False, # Create as inactive
)
# Balance should be 0 since transaction is inactive
assert balance == 0
# Enable the transaction with new metadata
enable_metadata = SafeJson(
{
"payment_method": "test_payment",
"activation_reason": "Integration test activation",
}
)
# This would fail with enum casting error before the fix
final_balance = await credit_system._enable_transaction(
transaction_key=tx_key,
user_id=user_id,
metadata=enable_metadata,
)
# Now balance should reflect the activated transaction
assert final_balance == 100
# Verify transaction was properly enabled with correct enum type
transaction = await CreditTransaction.prisma().find_first(
where={"userId": user_id, "transactionKey": tx_key}
)
assert transaction is not None
assert transaction.isActive is True
assert transaction.type == CreditTransactionType.TOP_UP
assert transaction.runningBalance == 100
# Verify metadata was updated
assert transaction.metadata is not None
assert transaction.metadata["payment_method"] == "test_payment"
assert transaction.metadata["activation_reason"] == "Integration test activation"
@pytest.mark.asyncio(loop_scope="session")
async def test_auto_top_up_configuration_storage(cleanup_test_user, monkeypatch):
"""
Test that auto-top-up configuration is properly stored and retrieved.
The immediate top-up logic is handled by the API routes, not the core
set_auto_top_up function. This test verifies the configuration is correctly
saved and can be retrieved.
"""
# Enable credits for this test
from backend.data.credit import settings
monkeypatch.setattr(settings.config, "enable_credit", True)
monkeypatch.setattr(settings.config, "enable_beta_monthly_credit", True)
monkeypatch.setattr(settings.config, "num_user_credits_refill", 1000)
user_id = cleanup_test_user
credit_system = BetaUserCredit(1000)
# Set initial balance
balance, _ = await credit_system._add_transaction(
user_id=user_id,
amount=50,
transaction_type=CreditTransactionType.GRANT,
metadata=SafeJson({"reason": "Initial balance for config test"}),
)
assert balance == 50
# Configure auto top-up
config = AutoTopUpConfig(threshold=100, amount=200)
await set_auto_top_up(user_id, config)
# Verify the configuration was saved
retrieved_config = await get_auto_top_up(user_id)
assert retrieved_config.threshold == config.threshold
assert retrieved_config.amount == config.amount
# Verify balance is unchanged (no immediate top-up from set_auto_top_up)
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 50 # Should be unchanged
# Verify no immediate auto-top-up transaction was created by set_auto_top_up
transactions = await CreditTransaction.prisma().find_many(
where={"userId": user_id}, order={"createdAt": "desc"}
)
# Should only have the initial GRANT transaction
assert len(transactions) == 1
assert transactions[0].type == CreditTransactionType.GRANT


@@ -0,0 +1,141 @@
"""
Tests for credit system metadata handling to ensure JSON casting works correctly.
This test verifies that metadata parameters are properly serialized when passed
to raw SQL queries with JSONB columns.
"""
# type: ignore
from typing import Any
import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, UserBalance
from backend.data.credit import BetaUserCredit
from backend.data.user import DEFAULT_USER_ID
from backend.util.json import SafeJson
@pytest.fixture
async def setup_test_user():
"""Setup test user and cleanup after test."""
user_id = DEFAULT_USER_ID
# Cleanup before test
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await UserBalance.prisma().delete_many(where={"userId": user_id})
yield user_id
# Cleanup after test
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await UserBalance.prisma().delete_many(where={"userId": user_id})
@pytest.mark.asyncio(loop_scope="session")
async def test_metadata_json_serialization(setup_test_user):
"""Test that metadata is properly serialized for JSONB column in raw SQL."""
user_id = setup_test_user
credit_system = BetaUserCredit(1000)
# Test with complex metadata that would fail if not properly serialized
complex_metadata = SafeJson(
{
"graph_exec_id": "test-12345",
"reason": "Testing metadata serialization",
"nested_data": {
"key1": "value1",
"key2": ["array", "of", "values"],
"key3": {"deeply": {"nested": "object"}},
},
"special_chars": "Testing 'quotes' and \"double quotes\" and unicode: 🚀",
}
)
# This should work without throwing a JSONB casting error
balance, tx_key = await credit_system._add_transaction(
user_id=user_id,
amount=500, # $5 top-up
transaction_type=CreditTransactionType.TOP_UP,
metadata=complex_metadata,
is_active=True,
)
# Verify the transaction was created successfully
assert balance == 500
# Verify the metadata was stored correctly in the database
transaction = await CreditTransaction.prisma().find_first(
where={"userId": user_id, "transactionKey": tx_key}
)
assert transaction is not None
assert transaction.metadata is not None
# Verify the metadata contains our complex data
metadata_dict: dict[str, Any] = dict(transaction.metadata) # type: ignore
assert metadata_dict["graph_exec_id"] == "test-12345"
assert metadata_dict["reason"] == "Testing metadata serialization"
assert metadata_dict["nested_data"]["key1"] == "value1"
assert metadata_dict["nested_data"]["key3"]["deeply"]["nested"] == "object"
assert (
metadata_dict["special_chars"]
== "Testing 'quotes' and \"double quotes\" and unicode: 🚀"
)
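The serialization requirement being tested can be sketched in isolation: a raw-SQL JSONB parameter must be passed as a JSON string, not a Python dict. A minimal sketch, assuming only the standard library (`SafeJson` presumably wraps something equivalent):

```python
import json


def to_jsonb_param(metadata: dict) -> str:
    # Raw SQL JSONB parameters are bound as JSON text and cast in the
    # query (e.g. $1::jsonb); ensure_ascii=False preserves unicode like 🚀.
    return json.dumps(metadata, ensure_ascii=False)


param = to_jsonb_param({"nested": {"key": ["a", "b"]}, "emoji": "🚀"})
```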
@pytest.mark.asyncio(loop_scope="session")
async def test_enable_transaction_metadata_serialization(setup_test_user):
"""Test that _enable_transaction also handles metadata JSON serialization correctly."""
user_id = setup_test_user
credit_system = BetaUserCredit(1000)
# First create an inactive transaction
balance, tx_key = await credit_system._add_transaction(
user_id=user_id,
amount=300,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"initial": "inactive_transaction"}),
is_active=False, # Create as inactive
)
# Initial balance should be 0 because transaction is inactive
assert balance == 0
# Now enable the transaction with new metadata
enable_metadata = SafeJson(
{
"payment_method": "stripe",
"payment_intent": "pi_test_12345",
"activation_reason": "Payment confirmed",
"complex_data": {"array": [1, 2, 3], "boolean": True, "null_value": None},
}
)
# This should work without JSONB casting errors
final_balance = await credit_system._enable_transaction(
transaction_key=tx_key,
user_id=user_id,
metadata=enable_metadata,
)
# Now balance should reflect the activated transaction
assert final_balance == 300
# Verify the metadata was updated correctly
transaction = await CreditTransaction.prisma().find_first(
where={"userId": user_id, "transactionKey": tx_key}
)
assert transaction is not None
assert transaction.isActive is True
# Verify the metadata was updated with enable_metadata
metadata_dict: dict[str, Any] = dict(transaction.metadata) # type: ignore
assert metadata_dict["payment_method"] == "stripe"
assert metadata_dict["payment_intent"] == "pi_test_12345"
assert metadata_dict["complex_data"]["array"] == [1, 2, 3]
assert metadata_dict["complex_data"]["boolean"] is True
assert metadata_dict["complex_data"]["null_value"] is None


@@ -0,0 +1,372 @@
"""
Tests for credit system refund and dispute operations.
These tests ensure that refund operations (deduct_credits, handle_dispute)
are atomic and maintain data consistency.
"""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
import stripe
from prisma.enums import CreditTransactionType
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
from backend.data.credit import UserCredit
from backend.util.json import SafeJson
from backend.util.test import SpinTestServer
credit_system = UserCredit()
# Test user ID for refund tests
REFUND_TEST_USER_ID = "refund-test-user"
async def setup_test_user_with_topup():
"""Create a test user with initial balance and a top-up transaction."""
# Clean up any existing data
await CreditRefundRequest.prisma().delete_many(
where={"userId": REFUND_TEST_USER_ID}
)
await CreditTransaction.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
await UserBalance.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
await User.prisma().delete_many(where={"id": REFUND_TEST_USER_ID})
# Create user
await User.prisma().create(
data={
"id": REFUND_TEST_USER_ID,
"email": f"{REFUND_TEST_USER_ID}@example.com",
"name": "Refund Test User",
}
)
# Create user balance
await UserBalance.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"balance": 1000, # $10
}
)
# Create a top-up transaction that can be refunded
topup_tx = await CreditTransaction.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"amount": 1000,
"type": CreditTransactionType.TOP_UP,
"transactionKey": "pi_test_12345",
"runningBalance": 1000,
"isActive": True,
"metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
}
)
return topup_tx
async def cleanup_test_user():
"""Clean up test data."""
await CreditRefundRequest.prisma().delete_many(
where={"userId": REFUND_TEST_USER_ID}
)
await CreditTransaction.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
await UserBalance.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
await User.prisma().delete_many(where={"id": REFUND_TEST_USER_ID})
@pytest.mark.asyncio(loop_scope="session")
async def test_deduct_credits_atomic(server: SpinTestServer):
"""Test that deduct_credits is atomic and creates transaction correctly."""
topup_tx = await setup_test_user_with_topup()
try:
# Create a mock refund object
refund = MagicMock(spec=stripe.Refund)
refund.id = "re_test_refund_123"
refund.payment_intent = topup_tx.transactionKey
refund.amount = 500 # Refund $5 of the $10 top-up
refund.status = "succeeded"
refund.reason = "requested_by_customer"
refund.created = int(datetime.now(timezone.utc).timestamp())
# Create refund request record (simulating webhook flow)
await CreditRefundRequest.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"amount": 500,
"transactionKey": topup_tx.transactionKey, # Should match the original transaction
"reason": "Test refund",
}
)
# Call deduct_credits
await credit_system.deduct_credits(refund)
# Verify the user's balance was deducted
user_balance = await UserBalance.prisma().find_unique(
where={"userId": REFUND_TEST_USER_ID}
)
assert user_balance is not None
assert (
user_balance.balance == 500
), f"Expected balance 500, got {user_balance.balance}"
# Verify refund transaction was created
refund_tx = await CreditTransaction.prisma().find_first(
where={
"userId": REFUND_TEST_USER_ID,
"type": CreditTransactionType.REFUND,
"transactionKey": refund.id,
}
)
assert refund_tx is not None
assert refund_tx.amount == -500
assert refund_tx.runningBalance == 500
assert refund_tx.isActive
# Verify refund request was updated
refund_request = await CreditRefundRequest.prisma().find_first(
where={
"userId": REFUND_TEST_USER_ID,
"transactionKey": topup_tx.transactionKey,
}
)
assert refund_request is not None
assert (
refund_request.result
== "The refund request has been approved, the amount will be credited back to your account."
)
finally:
await cleanup_test_user()
@pytest.mark.asyncio(loop_scope="session")
async def test_deduct_credits_user_not_found(server: SpinTestServer):
"""Test that deduct_credits raises error if transaction not found (which means user doesn't exist)."""
# Create a mock refund object that references a non-existent payment intent
refund = MagicMock(spec=stripe.Refund)
refund.id = "re_test_refund_nonexistent"
refund.payment_intent = "pi_test_nonexistent" # This payment intent doesn't exist
refund.amount = 500
refund.status = "succeeded"
refund.reason = "requested_by_customer"
refund.created = int(datetime.now(timezone.utc).timestamp())
# Should raise NotFoundError for the missing transaction
with pytest.raises(Exception):
await credit_system.deduct_credits(refund)
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.data.credit.settings")
@patch("stripe.Dispute.modify")
@patch("backend.data.credit.get_user_by_id")
async def test_handle_dispute_with_sufficient_balance(
mock_get_user, mock_stripe_modify, mock_settings, server: SpinTestServer
):
"""Test handling dispute when user has sufficient balance (dispute gets closed)."""
topup_tx = await setup_test_user_with_topup()
try:
# Mock settings to have a low tolerance threshold
mock_settings.config.refund_credit_tolerance_threshold = 0
# Mock the user lookup
mock_user = MagicMock()
mock_user.email = f"{REFUND_TEST_USER_ID}@example.com"
mock_get_user.return_value = mock_user
# Create a mock dispute object for small amount (user has 1000, disputing 100)
dispute = MagicMock(spec=stripe.Dispute)
dispute.id = "dp_test_dispute_123"
dispute.payment_intent = topup_tx.transactionKey
dispute.amount = 100 # Small dispute amount
dispute.status = "pending"
dispute.reason = "fraudulent"
dispute.created = int(datetime.now(timezone.utc).timestamp())
# Mock the close method to prevent real API calls
dispute.close = MagicMock()
# Handle the dispute
await credit_system.handle_dispute(dispute)
# Verify dispute.close() was called (since user has sufficient balance)
dispute.close.assert_called_once()
# Verify no stripe evidence was added since dispute was closed
mock_stripe_modify.assert_not_called()
# Verify the user's balance was NOT deducted (dispute was closed)
user_balance = await UserBalance.prisma().find_unique(
where={"userId": REFUND_TEST_USER_ID}
)
assert user_balance is not None
assert (
user_balance.balance == 1000
), f"Balance should remain 1000, got {user_balance.balance}"
finally:
await cleanup_test_user()
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.data.credit.settings")
@patch("stripe.Dispute.modify")
@patch("backend.data.credit.get_user_by_id")
async def test_handle_dispute_with_insufficient_balance(
mock_get_user, mock_stripe_modify, mock_settings, server: SpinTestServer
):
"""Test handling dispute when user has insufficient balance (evidence gets added)."""
topup_tx = await setup_test_user_with_topup()
# Save original method for restoration before any try blocks
original_get_history = credit_system.get_transaction_history
try:
# Mock settings to have a high tolerance threshold so dispute isn't closed
mock_settings.config.refund_credit_tolerance_threshold = 2000
# Mock the user lookup
mock_user = MagicMock()
mock_user.email = f"{REFUND_TEST_USER_ID}@example.com"
mock_get_user.return_value = mock_user
# Mock the transaction history method to return an async result
from unittest.mock import AsyncMock
mock_history = MagicMock()
mock_history.transactions = []
credit_system.get_transaction_history = AsyncMock(return_value=mock_history)
# Create a mock dispute object for full amount (user has 1000, disputing 1000)
dispute = MagicMock(spec=stripe.Dispute)
dispute.id = "dp_test_dispute_pending"
dispute.payment_intent = topup_tx.transactionKey
dispute.amount = 1000
dispute.status = "warning_needs_response"
dispute.created = int(datetime.now(timezone.utc).timestamp())
# Mock the close method to prevent real API calls
dispute.close = MagicMock()
# Handle the dispute (evidence should be added)
await credit_system.handle_dispute(dispute)
# Verify dispute.close() was NOT called (insufficient balance after tolerance)
dispute.close.assert_not_called()
# Verify stripe evidence was added since dispute wasn't closed
mock_stripe_modify.assert_called_once()
# Verify the user's balance was NOT deducted (handle_dispute doesn't deduct credits)
user_balance = await UserBalance.prisma().find_unique(
where={"userId": REFUND_TEST_USER_ID}
)
assert user_balance is not None
assert user_balance.balance == 1000, "Balance should remain unchanged"
finally:
credit_system.get_transaction_history = original_get_history
await cleanup_test_user()
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_refunds(server: SpinTestServer):
"""Test that concurrent refunds are handled atomically."""
import asyncio
topup_tx = await setup_test_user_with_topup()
try:
# Create multiple refund requests
refund_requests = []
for i in range(5):
req = await CreditRefundRequest.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"amount": 100, # $1 each
"transactionKey": topup_tx.transactionKey,
"reason": f"Test refund {i}",
}
)
refund_requests.append(req)
# Create refund tasks to run concurrently
async def process_refund(index: int):
refund = MagicMock(spec=stripe.Refund)
refund.id = f"re_test_concurrent_{index}"
refund.payment_intent = topup_tx.transactionKey
refund.amount = 100 # $1 refund
refund.status = "succeeded"
refund.reason = "requested_by_customer"
refund.created = int(datetime.now(timezone.utc).timestamp())
try:
await credit_system.deduct_credits(refund)
return "success"
except Exception as e:
return f"error: {e}"
# Run refunds concurrently
results = await asyncio.gather(
*[process_refund(i) for i in range(5)], return_exceptions=True
)
# All should succeed
assert all(r == "success" for r in results), f"Some refunds failed: {results}"
# Verify final balance: with the atomic implementation, all 5 refunds
# are applied without lost updates (1000 - 5 * 100 = 500)
user_balance = await UserBalance.prisma().find_unique(
where={"userId": REFUND_TEST_USER_ID}
)
assert user_balance is not None
assert (
user_balance.balance == 500
), f"Expected balance 500 after 5 refunds of 100 each, got {user_balance.balance}"
# Verify all refund transactions exist
refund_txs = await CreditTransaction.prisma().find_many(
where={
"userId": REFUND_TEST_USER_ID,
"type": CreditTransactionType.REFUND,
}
)
assert (
len(refund_txs) == 5
), f"Expected 5 refund transactions, got {len(refund_txs)}"
running_balances: set[int] = {
tx.runningBalance for tx in refund_txs if tx.runningBalance is not None
}
# Verify all balances are valid intermediate states
for balance in running_balances:
assert (
500 <= balance <= 1000
), f"Invalid balance {balance}, should be between 500 and 1000"
# Final balance should be present
assert (
500 in running_balances
), f"Final balance 500 should be in {running_balances}"
# All balances should be unique and form a valid sequence
sorted_balances = sorted(running_balances, reverse=True)
assert (
len(sorted_balances) == 5
), f"Expected 5 unique balances, got {len(sorted_balances)}"
finally:
await cleanup_test_user()
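The atomicity the concurrent-refund test relies on can be sketched in-process: the read-modify-write must happen as one unit, the analogue of a single `UPDATE ... SET balance = balance - $1` statement. A standalone sketch (names are illustrative, not the real `UserBalance` model):

```python
import asyncio


class Balance:
    """In-memory stand-in for a UserBalance row."""

    def __init__(self, value: int) -> None:
        self.value = value
        self._lock = asyncio.Lock()

    async def refund_atomic(self, amount: int) -> int:
        # Read and write under one lock: no concurrent refund can observe
        # a stale balance, so no update is lost.
        async with self._lock:
            self.value -= amount
            return self.value


async def run_refunds() -> int:
    b = Balance(1000)
    await asyncio.gather(*[b.refund_atomic(100) for _ in range(5)])
    return b.value


final = asyncio.run(run_refunds())
```

A naive version that reads the balance, awaits, then writes it back would intermittently finish above 500, which is the lost-update failure mode the test's running-balance checks are designed to catch.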


@@ -1,8 +1,8 @@
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction
from prisma.models import CreditTransaction, UserBalance
from backend.blocks.llm import AITextGeneratorBlock
from backend.data.block import get_block
@@ -19,14 +19,24 @@ user_credit = BetaUserCredit(REFILL_VALUE)
async def disable_test_user_transactions():
await CreditTransaction.prisma().delete_many(where={"userId": DEFAULT_USER_ID})
# Also reset the balance to 0 and set updatedAt to old date to trigger monthly refill
old_date = datetime.now(timezone.utc) - timedelta(days=35) # More than a month ago
await UserBalance.prisma().upsert(
where={"userId": DEFAULT_USER_ID},
data={
"create": {"userId": DEFAULT_USER_ID, "balance": 0},
"update": {"balance": 0, "updatedAt": old_date},
},
)
async def top_up(amount: int):
await user_credit._add_transaction(
balance, _ = await user_credit._add_transaction(
DEFAULT_USER_ID,
amount,
CreditTransactionType.TOP_UP,
)
return balance
async def spend_credits(entry: NodeExecutionEntry) -> int:
@@ -111,29 +121,90 @@ async def test_block_credit_top_up(server: SpinTestServer):
@pytest.mark.asyncio(loop_scope="session")
async def test_block_credit_reset(server: SpinTestServer):
"""Test that BetaUserCredit provides monthly refills correctly."""
await disable_test_user_transactions()
month1 = 1
month2 = 2
# set the calendar to month 2 but use current time from now
user_credit.time_now = lambda: datetime.now(timezone.utc).replace(
month=month2, day=1
)
month2credit = await user_credit.get_credits(DEFAULT_USER_ID)
# Save original time_now function for restoration
original_time_now = user_credit.time_now
# Month 1 result should only affect month 1
user_credit.time_now = lambda: datetime.now(timezone.utc).replace(
month=month1, day=1
)
month1credit = await user_credit.get_credits(DEFAULT_USER_ID)
await top_up(100)
assert await user_credit.get_credits(DEFAULT_USER_ID) == month1credit + 100
try:
# Test month 1 behavior
month1 = datetime.now(timezone.utc).replace(month=1, day=1)
user_credit.time_now = lambda: month1
# Month 2 balance is unaffected
user_credit.time_now = lambda: datetime.now(timezone.utc).replace(
month=month2, day=1
)
assert await user_credit.get_credits(DEFAULT_USER_ID) == month2credit
# First call in month 1 should trigger refill
balance = await user_credit.get_credits(DEFAULT_USER_ID)
assert balance == REFILL_VALUE # Should get 1000 credits
# Manually create a transaction with month 1 timestamp to establish history
await CreditTransaction.prisma().create(
data={
"userId": DEFAULT_USER_ID,
"amount": 100,
"type": CreditTransactionType.TOP_UP,
"runningBalance": 1100,
"isActive": True,
"createdAt": month1, # Set specific timestamp
}
)
# Update user balance to match
await UserBalance.prisma().upsert(
where={"userId": DEFAULT_USER_ID},
data={
"create": {"userId": DEFAULT_USER_ID, "balance": 1100},
"update": {"balance": 1100},
},
)
# Now test month 2 behavior
month2 = datetime.now(timezone.utc).replace(month=2, day=1)
user_credit.time_now = lambda: month2
# In month 2, since balance (1100) > refill (1000), no refill should happen
month2_balance = await user_credit.get_credits(DEFAULT_USER_ID)
assert month2_balance == 1100 # Balance persists, no reset
# Now test the refill behavior when balance is low
# Set balance below refill threshold
await UserBalance.prisma().update(
where={"userId": DEFAULT_USER_ID}, data={"balance": 400}
)
# Create a month 2 transaction to update the last transaction time
await CreditTransaction.prisma().create(
data={
"userId": DEFAULT_USER_ID,
"amount": -700, # Spent 700 to get to 400
"type": CreditTransactionType.USAGE,
"runningBalance": 400,
"isActive": True,
"createdAt": month2,
}
)
# Move to month 3
month3 = datetime.now(timezone.utc).replace(month=3, day=1)
user_credit.time_now = lambda: month3
# Should get refilled since balance (400) < refill value (1000)
month3_balance = await user_credit.get_credits(DEFAULT_USER_ID)
assert month3_balance == REFILL_VALUE # Should be refilled to 1000
# Verify the refill transaction was created
refill_tx = await CreditTransaction.prisma().find_first(
where={
"userId": DEFAULT_USER_ID,
"type": CreditTransactionType.GRANT,
"transactionKey": {"contains": "MONTHLY-CREDIT-TOP-UP"},
},
order={"createdAt": "desc"},
)
assert refill_tx is not None, "Monthly refill transaction should be created"
assert refill_tx.amount == 600, "Refill should be 600 (1000 - 400)"
finally:
# Restore original time_now function
user_credit.time_now = original_time_now
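The refill rule asserted above (refill 600 = 1000 - 400; no refill when the balance already exceeds the refill value) can be sketched as a pure function. This is an assumption about the behavior the test verifies, not the actual `BetaUserCredit` code:

```python
from datetime import datetime


def monthly_refill(balance: int, refill_value: int,
                   last_tx: datetime, now: datetime) -> int:
    """Grant the difference up to refill_value on the first call in a new
    calendar month, but only when the balance sits below the refill value."""
    new_month = (now.year, now.month) != (last_tx.year, last_tx.month)
    if new_month and balance < refill_value:
        return refill_value - balance
    return 0


# Month rolls over with a low balance: top up the difference.
assert monthly_refill(400, 1000, datetime(2024, 2, 1), datetime(2024, 3, 1)) == 600
# Balance already above the refill value: no grant.
assert monthly_refill(1100, 1000, datetime(2024, 1, 1), datetime(2024, 2, 1)) == 0
```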
@pytest.mark.asyncio(loop_scope="session")


@@ -0,0 +1,361 @@
"""
Test underflow protection for cumulative refunds and negative transactions.
This test ensures that when multiple large refunds are processed, the user balance
doesn't underflow below POSTGRES_INT_MIN, which could cause integer wraparound issues.
"""
import asyncio
from uuid import uuid4
import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from backend.data.credit import POSTGRES_INT_MIN, UserCredit
from backend.util.test import SpinTestServer
async def create_test_user(user_id: str) -> None:
"""Create a test user for underflow tests."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
)
except UniqueViolationError:
# User already exists, continue
pass
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
)
async def cleanup_test_user(user_id: str) -> None:
"""Clean up test user and their transactions."""
try:
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await UserBalance.prisma().delete_many(where={"userId": user_id})
await User.prisma().delete_many(where={"id": user_id})
except Exception as e:
# Log cleanup failures but don't fail the test
print(f"Warning: Failed to cleanup test user {user_id}: {e}")
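The clamping behavior this file tests can be sketched directly. `POSTGRES_INT_MIN` is shown here with the conventional 32-bit signed lower bound as an assumption; the tests import the real constant from `backend.data.credit`:

```python
POSTGRES_INT_MIN = -2_147_483_648  # assumed 32-bit INT lower bound


def apply_delta(balance: int, delta: int) -> int:
    # Clamp instead of wrapping: a cumulative refund can never push the
    # stored balance below what the INT column can represent.
    return max(POSTGRES_INT_MIN, balance + delta)


# (POSTGRES_INT_MIN + 100) + (-200) would underflow; protection clamps it.
assert apply_delta(POSTGRES_INT_MIN + 100, -200) == POSTGRES_INT_MIN
```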
@pytest.mark.asyncio(loop_scope="session")
async def test_debug_underflow_step_by_step(server: SpinTestServer):
"""Debug underflow behavior step by step."""
credit_system = UserCredit()
user_id = f"debug-underflow-{uuid4()}"
await create_test_user(user_id)
try:
print(f"POSTGRES_INT_MIN: {POSTGRES_INT_MIN}")
# Test 1: Set up balance close to underflow threshold
print("\n=== Test 1: Setting up balance close to underflow threshold ===")
# First, manually set balance to a value very close to POSTGRES_INT_MIN
# We'll set it to POSTGRES_INT_MIN + 100, then try to subtract 200
# This should trigger underflow protection: (POSTGRES_INT_MIN + 100) + (-200) = POSTGRES_INT_MIN - 100
initial_balance_target = POSTGRES_INT_MIN + 100
# Use direct database update to set the balance close to underflow
from prisma.models import UserBalance
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": initial_balance_target},
"update": {"balance": initial_balance_target},
},
)
current_balance = await credit_system.get_credits(user_id)
print(f"Set balance to: {current_balance}")
assert current_balance == initial_balance_target
# Test 2: Apply amount that should cause underflow
print("\n=== Test 2: Testing underflow protection ===")
test_amount = (
-200
) # This should cause underflow: (POSTGRES_INT_MIN + 100) + (-200) = POSTGRES_INT_MIN - 100
expected_without_protection = current_balance + test_amount
print(f"Current balance: {current_balance}")
print(f"Test amount: {test_amount}")
print(f"Without protection would be: {expected_without_protection}")
print(f"Should be clamped to POSTGRES_INT_MIN: {POSTGRES_INT_MIN}")
# Apply the amount that should trigger underflow protection
balance_result, _ = await credit_system._add_transaction(
user_id=user_id,
amount=test_amount,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False,
)
print(f"Actual result: {balance_result}")
# Check if underflow protection worked
assert (
balance_result == POSTGRES_INT_MIN
), f"Expected underflow protection to clamp balance to {POSTGRES_INT_MIN}, got {balance_result}"
# Test 3: Edge case - exactly at POSTGRES_INT_MIN
print("\n=== Test 3: Testing exact POSTGRES_INT_MIN boundary ===")
# Set balance to exactly POSTGRES_INT_MIN
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
"update": {"balance": POSTGRES_INT_MIN},
},
)
edge_balance = await credit_system.get_credits(user_id)
print(f"Balance set to exactly POSTGRES_INT_MIN: {edge_balance}")
# Try to subtract 1 - should stay at POSTGRES_INT_MIN
edge_result, _ = await credit_system._add_transaction(
user_id=user_id,
amount=-1,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False,
)
print(f"After subtracting 1: {edge_result}")
assert (
edge_result == POSTGRES_INT_MIN
), f"Expected balance to remain clamped at {POSTGRES_INT_MIN}, got {edge_result}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_underflow_protection_large_refunds(server: SpinTestServer):
"""Test that large cumulative refunds don't cause integer underflow."""
credit_system = UserCredit()
user_id = f"underflow-test-{uuid4()}"
await create_test_user(user_id)
try:
# Set up balance close to underflow threshold to test the protection
# Set balance to POSTGRES_INT_MIN + 1000, then try to subtract 2000
# This should trigger underflow protection
from prisma.models import UserBalance
test_balance = POSTGRES_INT_MIN + 1000
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": test_balance},
"update": {"balance": test_balance},
},
)
current_balance = await credit_system.get_credits(user_id)
assert current_balance == test_balance
# Try to deduct amount that would cause underflow: test_balance + (-2000) = POSTGRES_INT_MIN - 1000
underflow_amount = -2000
expected_without_protection = (
current_balance + underflow_amount
) # Should be POSTGRES_INT_MIN - 1000
# Use _add_transaction directly with amount that would cause underflow
final_balance, _ = await credit_system._add_transaction(
user_id=user_id,
amount=underflow_amount,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False, # Allow going negative for refunds
)
# Balance should be clamped to POSTGRES_INT_MIN, not the calculated underflow value
assert (
final_balance == POSTGRES_INT_MIN
), f"Balance should be clamped to {POSTGRES_INT_MIN}, got {final_balance}"
assert (
final_balance > expected_without_protection
), f"Balance should be greater than underflow result {expected_without_protection}, got {final_balance}"
# Verify with get_credits too
stored_balance = await credit_system.get_credits(user_id)
assert (
stored_balance == POSTGRES_INT_MIN
), f"Stored balance should be {POSTGRES_INT_MIN}, got {stored_balance}"
# Verify transaction was created with the underflow-protected balance
transactions = await CreditTransaction.prisma().find_many(
where={"userId": user_id, "type": CreditTransactionType.REFUND},
order={"createdAt": "desc"},
)
assert len(transactions) > 0, "Refund transaction should be created"
assert (
transactions[0].runningBalance == POSTGRES_INT_MIN
), f"Transaction should show clamped balance {POSTGRES_INT_MIN}, got {transactions[0].runningBalance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServer):
"""Test that multiple large refunds applied sequentially don't cause underflow."""
credit_system = UserCredit()
user_id = f"cumulative-underflow-test-{uuid4()}"
await create_test_user(user_id)
try:
# Set up balance close to underflow threshold
from prisma.models import UserBalance
initial_balance = POSTGRES_INT_MIN + 500 # Close to minimum but with some room
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": initial_balance},
"update": {"balance": initial_balance},
},
)
# Apply multiple refunds that would cumulatively underflow
refund_amount = -300  # Each refund is safe on its own, but causes underflow cumulatively
# First refund: (POSTGRES_INT_MIN + 500) + (-300) = POSTGRES_INT_MIN + 200 (still above minimum)
balance_1, _ = await credit_system._add_transaction(
user_id=user_id,
amount=refund_amount,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False,
)
# Should be above minimum for first refund
expected_balance_1 = (
initial_balance + refund_amount
) # Should be POSTGRES_INT_MIN + 200
assert (
balance_1 == expected_balance_1
), f"First refund should result in {expected_balance_1}, got {balance_1}"
assert (
balance_1 >= POSTGRES_INT_MIN
), f"First refund should not go below {POSTGRES_INT_MIN}, got {balance_1}"
# Second refund: (POSTGRES_INT_MIN + 200) + (-300) = POSTGRES_INT_MIN - 100 (would underflow)
balance_2, _ = await credit_system._add_transaction(
user_id=user_id,
amount=refund_amount,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False,
)
# Should be clamped to minimum due to underflow protection
assert (
balance_2 == POSTGRES_INT_MIN
), f"Second refund should be clamped to {POSTGRES_INT_MIN}, got {balance_2}"
# Third refund: Should stay at minimum
balance_3, _ = await credit_system._add_transaction(
user_id=user_id,
amount=refund_amount,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False,
)
# Should still be at minimum
assert (
balance_3 == POSTGRES_INT_MIN
), f"Third refund should stay at {POSTGRES_INT_MIN}, got {balance_3}"
# Final balance check
final_balance = await credit_system.get_credits(user_id)
assert (
final_balance == POSTGRES_INT_MIN
), f"Final balance should be {POSTGRES_INT_MIN}, got {final_balance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):
"""Test that concurrent large refunds don't cause race condition underflow."""
credit_system = UserCredit()
user_id = f"concurrent-underflow-test-{uuid4()}"
await create_test_user(user_id)
try:
# Set up balance close to underflow threshold
from prisma.models import UserBalance
initial_balance = POSTGRES_INT_MIN + 1000 # Close to minimum
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": initial_balance},
"update": {"balance": initial_balance},
},
)
async def large_refund(amount: int, label: str):
try:
return await credit_system._add_transaction(
user_id=user_id,
amount=-amount,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False,
)
except Exception as e:
return f"FAILED-{label}: {e}"
# Run concurrent refunds that would cause underflow if not protected:
# three refunds of 500 total 1500, exceeding the 1000 of headroom above POSTGRES_INT_MIN
refund_amount = 500
results = await asyncio.gather(
large_refund(refund_amount, "A"),
large_refund(refund_amount, "B"),
large_refund(refund_amount, "C"),
return_exceptions=True,
)
# Check all results are valid and no underflow occurred
valid_results = []
for i, result in enumerate(results):
if isinstance(result, tuple):
balance, _ = result
assert (
balance >= POSTGRES_INT_MIN
), f"Result {i} balance {balance} underflowed below {POSTGRES_INT_MIN}"
valid_results.append(balance)
elif isinstance(result, str) and "FAILED" in result:
# Some operations might fail due to validation, which is okay
pass
else:
# Unexpected exception
assert not isinstance(
result, Exception
), f"Unexpected exception in result {i}: {result}"
# At least one operation should succeed
assert (
len(valid_results) > 0
), f"At least one refund should succeed, got results: {results}"
# All successful results should be >= POSTGRES_INT_MIN
for balance in valid_results:
assert (
balance >= POSTGRES_INT_MIN
), f"Balance {balance} should not be below {POSTGRES_INT_MIN}"
# Final balance should be valid and at or above POSTGRES_INT_MIN
final_balance = await credit_system.get_credits(user_id)
assert (
final_balance >= POSTGRES_INT_MIN
), f"Final balance {final_balance} should not underflow below {POSTGRES_INT_MIN}"
finally:
await cleanup_test_user(user_id)
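The tests above all assert the same invariant: whatever the starting balance and refund amount, the stored balance is clamped at POSTGRES_INT_MIN rather than wrapping around. A minimal pure-Python sketch of that clamping rule (the constant's value and the function name here are assumptions for illustration, not the actual backend implementation):

```python
# Assumed value of PostgreSQL's signed 32-bit integer minimum,
# matching the POSTGRES_INT_MIN constant imported by the tests.
POSTGRES_INT_MIN = -(2**31)


def apply_amount(balance: int, amount: int) -> int:
    """Hypothetical clamping rule: never let the balance wrap below the column minimum."""
    return max(balance + amount, POSTGRES_INT_MIN)


# Mirrors the step-by-step debug test above:
assert apply_amount(POSTGRES_INT_MIN + 100, -200) == POSTGRES_INT_MIN
assert apply_amount(POSTGRES_INT_MIN, -1) == POSTGRES_INT_MIN
assert apply_amount(POSTGRES_INT_MIN + 500, -300) == POSTGRES_INT_MIN + 200
```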



@@ -0,0 +1,217 @@
"""
Integration test to verify complete migration from User.balance to UserBalance table.
This test ensures that:
1. No User.balance queries exist in the system
2. All balance operations go through UserBalance table
3. User and UserBalance stay synchronized properly
"""
import asyncio
from datetime import datetime
import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from backend.data.credit import UsageTransactionMetadata, UserCredit
from backend.util.json import SafeJson
from backend.util.test import SpinTestServer
async def create_test_user(user_id: str) -> None:
"""Create a test user for migration tests."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
)
except UniqueViolationError:
# User already exists, continue
pass
async def cleanup_test_user(user_id: str) -> None:
"""Clean up test user and their data."""
try:
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await UserBalance.prisma().delete_many(where={"userId": user_id})
await User.prisma().delete_many(where={"id": user_id})
except Exception as e:
# Log cleanup failures but don't fail the test
print(f"Warning: Failed to cleanup test user {user_id}: {e}")
@pytest.mark.asyncio(loop_scope="session")
async def test_user_balance_migration_complete(server: SpinTestServer):
"""Test that User table balance is never used and UserBalance is source of truth."""
credit_system = UserCredit()
user_id = f"migration-test-{datetime.now().timestamp()}"
await create_test_user(user_id)
try:
# 1. Verify User table does NOT have balance set initially
user = await User.prisma().find_unique(where={"id": user_id})
assert user is not None
# User.balance should not exist or should be None/0 if it exists
user_balance_attr = getattr(user, "balance", None)
if user_balance_attr is not None:
assert (
user_balance_attr == 0 or user_balance_attr is None
), f"User.balance should be 0 or None, got {user_balance_attr}"
# 2. Perform various credit operations using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=1000,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "migration_test"}),
)
balance1 = await credit_system.get_credits(user_id)
assert balance1 == 1000
await credit_system.spend_credits(
user_id,
300,
UsageTransactionMetadata(
graph_exec_id="test", reason="Migration test spend"
),
)
balance2 = await credit_system.get_credits(user_id)
assert balance2 == 700
# 3. Verify UserBalance table has correct values
user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
assert user_balance is not None
assert (
user_balance.balance == 700
), f"UserBalance should be 700, got {user_balance.balance}"
# 4. CRITICAL: Verify User.balance is NEVER updated during operations
user_after = await User.prisma().find_unique(where={"id": user_id})
assert user_after is not None
user_balance_after = getattr(user_after, "balance", None)
if user_balance_after is not None:
# If User.balance exists, it should still be 0 (never updated)
assert (
user_balance_after == 0 or user_balance_after is None
), f"User.balance should remain 0/None after operations, got {user_balance_after}. This indicates User.balance is still being used!"
# 5. Verify get_credits always returns UserBalance value, not User.balance
final_balance = await credit_system.get_credits(user_id)
assert (
final_balance == user_balance.balance
), f"get_credits should return UserBalance value {user_balance.balance}, got {final_balance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_detect_stale_user_balance_queries(server: SpinTestServer):
"""Test to detect if any operations are still using User.balance instead of UserBalance."""
credit_system = UserCredit()
user_id = f"stale-query-test-{datetime.now().timestamp()}"
await create_test_user(user_id)
try:
# Create UserBalance with specific value
await UserBalance.prisma().create(
data={"userId": user_id, "balance": 5000} # $50
)
# Verify that get_credits returns UserBalance value (5000), not any stale User.balance value
balance = await credit_system.get_credits(user_id)
assert (
balance == 5000
), f"Expected get_credits to return 5000 from UserBalance, got {balance}"
# Verify all operations use UserBalance using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=1000,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "final_verification"}),
)
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 6000, f"Expected 6000, got {final_balance}"
# Verify UserBalance table has the correct value
user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
assert user_balance is not None
assert (
user_balance.balance == 6000
), f"UserBalance should be 6000, got {user_balance.balance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer):
"""Test that concurrent operations all use UserBalance locking, not User.balance."""
credit_system = UserCredit()
user_id = f"concurrent-userbalance-test-{datetime.now().timestamp()}"
await create_test_user(user_id)
try:
# Set initial balance in UserBalance
await UserBalance.prisma().create(data={"userId": user_id, "balance": 1000})
# Run concurrent operations to ensure they all use UserBalance atomic operations
async def concurrent_spend(amount: int, label: str):
try:
await credit_system.spend_credits(
user_id,
amount,
UsageTransactionMetadata(
graph_exec_id=f"concurrent-{label}",
reason=f"Concurrent test {label}",
),
)
return f"{label}-SUCCESS"
except Exception as e:
return f"{label}-FAILED: {e}"
# Run concurrent operations
results = await asyncio.gather(
concurrent_spend(100, "A"),
concurrent_spend(200, "B"),
concurrent_spend(300, "C"),
return_exceptions=True,
)
# All should succeed (1000 >= 100+200+300)
successful = [r for r in results if "SUCCESS" in str(r)]
assert len(successful) == 3, f"All operations should succeed, got {results}"
# Final balance should be 1000 - 600 = 400
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 400, f"Expected final balance 400, got {final_balance}"
# Verify UserBalance has correct value
user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
assert user_balance is not None
assert (
user_balance.balance == 400
), f"UserBalance should be 400, got {user_balance.balance}"
# Critical: If User.balance exists and was used, it might have wrong value
try:
user = await User.prisma().find_unique(where={"id": user_id})
user_balance_attr = getattr(user, "balance", None)
if user_balance_attr is not None:
# If User.balance exists, it should NOT be used for operations
# The fact that our final balance is correct from UserBalance proves the system is working
print(
f"✅ User.balance exists ({user_balance_attr}) but UserBalance ({user_balance.balance}) is being used correctly"
)
except Exception:
print("✅ User.balance column doesn't exist - migration is complete")
finally:
await cleanup_test_user(user_id)


@@ -98,42 +98,6 @@ async def transaction(timeout: int = TRANSACTION_TIMEOUT):
yield tx
@asynccontextmanager
async def locked_transaction(key: str, timeout: int = TRANSACTION_TIMEOUT):
"""
Create a transaction and take a per-key advisory *transaction* lock.
- Uses a 64-bit lock id via hashtextextended(key, 0) to avoid 32-bit collisions.
- Bound by lock_timeout and statement_timeout so it won't block indefinitely.
- Lock is held for the duration of the transaction and auto-released on commit/rollback.
Args:
key: String lock key (e.g., "usr_trx_<uuid>").
timeout: Transaction/lock/statement timeout in milliseconds.
"""
async with transaction(timeout=timeout) as tx:
# Ensure we don't wait longer than desired
# Note: SET LOCAL doesn't support parameterized queries, must use string interpolation
await tx.execute_raw(f"SET LOCAL statement_timeout = '{int(timeout)}ms'") # type: ignore[arg-type]
await tx.execute_raw(f"SET LOCAL lock_timeout = '{int(timeout)}ms'") # type: ignore[arg-type]
# Block until acquired or lock_timeout hits
try:
await tx.execute_raw(
"SELECT pg_advisory_xact_lock(hashtextextended($1, 0))",
key,
)
except Exception as e:
# Normalize PG's lock timeout error to TimeoutError for callers
if "lock timeout" in str(e).lower():
raise TimeoutError(
f"Could not acquire lock for key={key!r} within {timeout}ms"
) from e
raise
yield tx
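The advisory-lock mechanics here are PostgreSQL-specific, but the contract — one holder per key, bounded waiting, release on exit — can be sketched as an in-process asyncio analogue. This is illustrative only; the function name, timeout semantics, and per-key dictionary are assumptions, not the backend's API:

```python
import asyncio
from collections import defaultdict
from contextlib import asynccontextmanager

# One asyncio.Lock per key stands in for
# pg_advisory_xact_lock(hashtextextended(key, 0)).
_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)


@asynccontextmanager
async def locked(key: str, timeout_ms: int = 1000):
    lock = _locks[key]
    try:
        # Bounded wait, mirroring lock_timeout.
        await asyncio.wait_for(lock.acquire(), timeout=timeout_ms / 1000)
    except asyncio.TimeoutError as e:
        raise TimeoutError(
            f"Could not acquire lock for key={key!r} within {timeout_ms}ms"
        ) from e
    try:
        yield
    finally:
        # Auto-release on exit, mirroring release on commit/rollback.
        lock.release()
```

Unlike the real advisory lock, this only serializes coroutines in one process; the database version coordinates across all connected services.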
def get_database_schema() -> str:
"""Extract database schema from DATABASE_URL."""
parsed_url = urlparse(DATABASE_URL)


@@ -38,8 +38,8 @@ from prisma.types import (
from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
from pydantic.fields import Field
from backend.server.v2.store.exceptions import DatabaseError
from backend.util import type as type_utils
from backend.util.exceptions import DatabaseError
from backend.util.json import SafeJson
from backend.util.models import Pagination
from backend.util.retry import func_retry
@@ -478,6 +478,48 @@ async def get_graph_executions(
return [GraphExecutionMeta.from_db(execution) for execution in executions]
async def get_graph_executions_count(
user_id: Optional[str] = None,
graph_id: Optional[str] = None,
statuses: Optional[list[ExecutionStatus]] = None,
created_time_gte: Optional[datetime] = None,
created_time_lte: Optional[datetime] = None,
) -> int:
"""
Get count of graph executions with optional filters.
Args:
user_id: Optional user ID to filter by
graph_id: Optional graph ID to filter by
statuses: Optional list of execution statuses to filter by
created_time_gte: Optional minimum creation time
created_time_lte: Optional maximum creation time
Returns:
Count of matching graph executions
"""
where_filter: AgentGraphExecutionWhereInput = {
"isDeleted": False,
}
if user_id:
where_filter["userId"] = user_id
if graph_id:
where_filter["agentGraphId"] = graph_id
if created_time_gte or created_time_lte:
where_filter["createdAt"] = {
"gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
"lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
}
if statuses:
where_filter["OR"] = [{"executionStatus": status} for status in statuses]
count = await AgentGraphExecution.prisma().count(where=where_filter)
return count
class GraphExecutionsPaginated(BaseModel):
"""Response schema for paginated graph executions."""


@@ -7,7 +7,7 @@ from prisma.enums import AgentExecutionStatus
from backend.data.execution import get_graph_executions
from backend.data.graph import get_graph_metadata
from backend.data.model import UserExecutionSummaryStats
from backend.server.v2.store.exceptions import DatabaseError
from backend.util.exceptions import DatabaseError
from backend.util.logging import TruncatedLogger
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[SummaryData]")


@@ -129,17 +129,20 @@ class NodeModel(Node):
Returns a copy of the node model, stripped of any non-transferable properties
"""
stripped_node = self.model_copy(deep=True)
# Remove credentials from node input
# Remove credentials and other (possible) secrets from node input
if stripped_node.input_default:
stripped_node.input_default = NodeModel._filter_secrets_from_node_input(
stripped_node.input_default, self.block.input_schema.jsonschema()
)
# Remove default secret value from secret input nodes
if (
stripped_node.block.block_type == BlockType.INPUT
and stripped_node.input_default.get("secret", False) is True
and "value" in stripped_node.input_default
):
stripped_node.input_default["value"] = ""
del stripped_node.input_default["value"]
# Remove webhook info
stripped_node.webhook_id = None
@@ -156,8 +159,10 @@ class NodeModel(Node):
result = {}
for key, value in input_data.items():
field_schema: dict | None = field_schemas.get(key)
if (field_schema and field_schema.get("secret", False)) or any(
sensitive_key in key.lower() for sensitive_key in sensitive_keys
if (field_schema and field_schema.get("secret", False)) or (
any(sensitive_key in key.lower() for sensitive_key in sensitive_keys)
# Prevent removing `secret` flag on input nodes
and type(value) is not bool
):
# This is a secret value -> filter this key-value pair out
continue
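Under the rule in this hunk, a key is dropped when its schema marks it secret, or when its name contains a sensitive substring and its value is not a bool (the bool escape is what keeps the `secret` flag itself on input nodes). An illustrative standalone version — the sensitive-key list here is an assumption, since the real `sensitive_keys` is defined outside this hunk:

```python
# Assumed list; the actual sensitive_keys is defined elsewhere in the module.
SENSITIVE_KEYS = ("api_key", "password", "token", "credentials", "secret")


def filter_secrets(input_data: dict, field_schemas: dict) -> dict:
    """Hypothetical re-implementation of the filtering condition above."""
    result = {}
    for key, value in input_data.items():
        field_schema = field_schemas.get(key)
        if (field_schema and field_schema.get("secret", False)) or (
            any(s in key.lower() for s in SENSITIVE_KEYS)
            # Prevent removing the `secret` flag on input nodes
            and type(value) is not bool
        ):
            # Secret value -> filter this key-value pair out
            continue
        result[key] = value
    return result
```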


@@ -201,25 +201,56 @@ async def test_get_input_schema(server: SpinTestServer, snapshot: Snapshot):
@pytest.mark.asyncio(loop_scope="session")
async def test_clean_graph(server: SpinTestServer):
"""
Test the clean_graph function that:
1. Clears input block values
2. Removes credentials from nodes
Test the stripped_for_export function that:
1. Removes sensitive/secret fields from node inputs
2. Removes webhook information
3. Preserves non-sensitive data including input block values
"""
# Create a graph with input blocks and credentials
# Create a graph with input blocks containing both sensitive and normal data
graph = Graph(
id="test_clean_graph",
name="Test Clean Graph",
description="Test graph cleaning",
nodes=[
Node(
id="input_node",
block_id=AgentInputBlock().id,
input_default={
"_test_id": "input_node",
"name": "test_input",
"value": "test value",
"value": "test value", # This should be preserved
"description": "Test input description",
},
),
Node(
block_id=AgentInputBlock().id,
input_default={
"_test_id": "input_node_secret",
"name": "secret_input",
"value": "another value",
"secret": True, # This makes the input secret
},
),
Node(
block_id=StoreValueBlock().id,
input_default={
"_test_id": "node_with_secrets",
"input": "normal_value",
"control_test_input": "should be preserved",
"api_key": "secret_api_key_123", # Should be filtered
"password": "secret_password_456", # Should be filtered
"token": "secret_token_789", # Should be filtered
"credentials": { # Should be filtered
"id": "fake-github-credentials-id",
"provider": "github",
"type": "api_key",
},
"anthropic_credentials": { # Should be filtered
"id": "fake-anthropic-credentials-id",
"provider": "anthropic",
"type": "api_key",
},
},
),
],
links=[],
)
@@ -231,15 +262,54 @@ async def test_clean_graph(server: SpinTestServer):
)
# Clean the graph
created_graph = await server.agent_server.test_get_graph(
cleaned_graph = await server.agent_server.test_get_graph(
created_graph.id, created_graph.version, DEFAULT_USER_ID, for_export=True
)
# # Verify input block value is cleared
# Verify sensitive fields are removed but normal fields are preserved
input_node = next(
n for n in created_graph.nodes if n.block_id == AgentInputBlock().id
n for n in cleaned_graph.nodes if n.input_default["_test_id"] == "input_node"
)
assert input_node.input_default["value"] == ""
# Non-sensitive fields should be preserved
assert input_node.input_default["name"] == "test_input"
assert input_node.input_default["value"] == "test value" # Should be preserved now
assert input_node.input_default["description"] == "Test input description"
# Sensitive fields should be filtered out
assert "api_key" not in input_node.input_default
assert "password" not in input_node.input_default
# Verify secret input node preserves non-sensitive fields but removes secret value
secret_node = next(
n
for n in cleaned_graph.nodes
if n.input_default["_test_id"] == "input_node_secret"
)
assert secret_node.input_default["name"] == "secret_input"
assert "value" not in secret_node.input_default # Secret default should be removed
assert secret_node.input_default["secret"] is True
# Verify sensitive fields are filtered from nodes with secrets
secrets_node = next(
n
for n in cleaned_graph.nodes
if n.input_default["_test_id"] == "node_with_secrets"
)
# Normal fields should be preserved
assert secrets_node.input_default["input"] == "normal_value"
assert secrets_node.input_default["control_test_input"] == "should be preserved"
# Sensitive fields should be filtered out
assert "api_key" not in secrets_node.input_default
assert "password" not in secrets_node.input_default
assert "token" not in secrets_node.input_default
assert "credentials" not in secrets_node.input_default
assert "anthropic_credentials" not in secrets_node.input_default
# Verify webhook info is removed (if any nodes had it)
for node in cleaned_graph.nodes:
assert node.webhook_id is None
assert node.webhook is None
@pytest.mark.asyncio(loop_scope="session")


@@ -15,7 +15,7 @@ from prisma.types import (
# from backend.notifications.models import NotificationEvent
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
from backend.server.v2.store.exceptions import DatabaseError
from backend.util.exceptions import DatabaseError
from backend.util.json import SafeJson
from backend.util.logging import TruncatedLogger
@@ -235,6 +235,7 @@ class BaseEventModel(BaseModel):
class NotificationEventModel(BaseEventModel, Generic[NotificationDataType_co]):
id: Optional[str] = None # None when creating, populated when reading from DB
data: NotificationDataType_co
@property
@@ -378,6 +379,7 @@ class NotificationPreference(BaseModel):
class UserNotificationEventDTO(BaseModel):
id: str # Added to track notifications for removal
type: NotificationType
data: dict
created_at: datetime
@@ -386,6 +388,7 @@ class UserNotificationEventDTO(BaseModel):
@staticmethod
def from_db(model: NotificationEvent) -> "UserNotificationEventDTO":
return UserNotificationEventDTO(
id=model.id,
type=model.type,
data=dict(model.data),
created_at=model.createdAt,
@@ -541,6 +544,79 @@ async def empty_user_notification_batch(
) from e
async def clear_all_user_notification_batches(user_id: str) -> None:
"""Clear ALL notification batches for a user across all types.
Used when user's email is bounced/inactive and we should stop
trying to send them ANY emails.
"""
try:
# Delete all notification events for this user
await NotificationEvent.prisma().delete_many(
where={"UserNotificationBatch": {"is": {"userId": user_id}}}
)
# Delete all batches for this user
await UserNotificationBatch.prisma().delete_many(where={"userId": user_id})
logger.info(f"Cleared all notification batches for user {user_id}")
except Exception as e:
raise DatabaseError(
f"Failed to clear all notification batches for user {user_id}: {e}"
) from e
async def remove_notifications_from_batch(
user_id: str, notification_type: NotificationType, notification_ids: list[str]
) -> None:
"""Remove specific notifications from a user's batch by their IDs.
This is used after successful sending to remove only the
sent notifications, preventing duplicates on retry.
"""
if not notification_ids:
return
try:
# Delete the specific notification events
deleted_count = await NotificationEvent.prisma().delete_many(
where={
"id": {"in": notification_ids},
"UserNotificationBatch": {
"is": {"userId": user_id, "type": notification_type}
},
}
)
logger.info(
f"Removed {deleted_count} notifications from batch for user {user_id}"
)
# Check if batch is now empty and delete it if so
remaining = await NotificationEvent.prisma().count(
where={
"UserNotificationBatch": {
"is": {"userId": user_id, "type": notification_type}
}
}
)
if remaining == 0:
await UserNotificationBatch.prisma().delete_many(
where=UserNotificationBatchWhereInput(
userId=user_id,
type=notification_type,
)
)
logger.info(
f"Deleted empty batch for user {user_id} and type {notification_type}"
)
except Exception as e:
raise DatabaseError(
f"Failed to remove notifications from batch for user {user_id} and type {notification_type}: {e}"
) from e
async def get_user_notification_batch(
user_id: str,
notification_type: NotificationType,

View File

@@ -4,16 +4,15 @@ from typing import Any, Optional
import prisma
import pydantic
from autogpt_libs.utils.cache import cached
from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
from backend.data.block import get_blocks
from backend.data.credit import get_user_credit_model
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.server.v2.store.model import StoreAgentDetails
from backend.util.cache import cached
from backend.util.json import SafeJson
# Mapping from user reason id to categories to search for when choosing agent to show
@@ -27,8 +26,6 @@ REASON_MAPPING: dict[str, list[str]] = {
POINTS_AGENT_COUNT = 50 # Number of agents to calculate points for
MIN_AGENT_COUNT = 2 # Minimum number of marketplace agents to enable onboarding
user_credit = get_user_credit_model()
class UserOnboardingUpdate(pydantic.BaseModel):
completedSteps: Optional[list[OnboardingStep]] = None
@@ -148,7 +145,8 @@ async def reward_user(user_id: str, step: OnboardingStep):
return
onboarding.rewardedFor.append(step)
await user_credit.onboarding_reward(user_id, reward, step)
user_credit_model = await get_user_credit_model(user_id)
await user_credit_model.onboarding_reward(user_id, reward, step)
await UserOnboarding.prisma().update(
where={"userId": user_id},
data={
@@ -278,8 +276,14 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
for word in user_onboarding.integrations
]
where_clause["is_available"] = True
# Try to take only agents that are available and allowed for onboarding
storeAgents = await prisma.models.StoreAgent.prisma().find_many(
where=prisma.types.StoreAgentWhereInput(**where_clause),
where={
"is_available": True,
"useForOnboarding": True,
},
order=[
{"featured": "desc"},
{"runs": "desc"},
@@ -288,59 +292,16 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
take=100,
)
agentListings = await prisma.models.StoreListingVersion.prisma().find_many(
where={
"id": {"in": [agent.storeListingVersionId for agent in storeAgents]},
},
include={"AgentGraph": True},
)
for listing in agentListings:
agent = listing.AgentGraph
if agent is None:
continue
graph = GraphModel.from_db(agent)
# Remove agents with empty input schema
if not graph.input_schema:
storeAgents = [
a for a in storeAgents if a.storeListingVersionId != listing.id
]
continue
# Remove agents with empty credentials
# Get nodes from this agent that have credentials
nodes = await prisma.models.AgentNode.prisma().find_many(
where={
"agentGraphId": agent.id,
"agentBlockId": {"in": list(CREDENTIALS_FIELDS.keys())},
},
)
for node in nodes:
block_id = node.agentBlockId
field_name = CREDENTIALS_FIELDS[block_id]
# If there are no credentials or they are empty, remove the agent
# FIXME ignores default values
if (
field_name not in node.constantInput
or node.constantInput[field_name] is None
):
storeAgents = [
a for a in storeAgents if a.storeListingVersionId != listing.id
]
break
# If there are fewer than 2 agents, add more agents to the list
# If not enough agents found, relax the useForOnboarding filter
if len(storeAgents) < 2:
storeAgents += await prisma.models.StoreAgent.prisma().find_many(
where={
"listing_id": {"not_in": [agent.listing_id for agent in storeAgents]},
},
storeAgents = await prisma.models.StoreAgent.prisma().find_many(
where=prisma.types.StoreAgentWhereInput(**where_clause),
order=[
{"featured": "desc"},
{"runs": "desc"},
{"rating": "desc"},
],
take=2 - len(storeAgents),
take=100,
)
# Calculate points for the first X agents and choose the top 2

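The fallback above (query with the strict `useForOnboarding` filter first, then rerun without it when fewer than two agents match) can be sketched as a generic two-pass helper. `fetch` and the filter dicts here are hypothetical stand-ins, not the real Prisma API:

```python
from typing import Any, Callable

def query_with_fallback(
    fetch: Callable[[dict[str, Any]], list[Any]],
    strict_filter: dict[str, Any],
    relaxed_filter: dict[str, Any],
    minimum: int = 2,
) -> list[Any]:
    """Run the strict query first; if too few rows come back, rerun relaxed."""
    results = fetch(strict_filter)
    if len(results) < minimum:
        # Relaxing replaces the whole result set (as in the diff above),
        # rather than topping up the strict results.
        results = fetch(relaxed_filter)
    return results
```

The replace-rather-than-top-up choice keeps the ordering (`featured`, `runs`, `rating`) consistent across the final result set.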

@@ -1,24 +1,29 @@
import logging
import os
from dotenv import load_dotenv
from redis import Redis
from redis.asyncio import Redis as AsyncRedis
from backend.util.cache import cached, thread_cached
from backend.util.retry import conn_retry
from backend.util.settings import Settings
settings = Settings()
load_dotenv()
HOST = os.getenv("REDIS_HOST", "localhost")
PORT = int(os.getenv("REDIS_PORT", "6379"))
PASSWORD = os.getenv("REDIS_PASSWORD", None)
logger = logging.getLogger(__name__)
@conn_retry("Redis", "Acquiring connection")
def connect(decode_responses: bool = True) -> Redis:
def connect() -> Redis:
c = Redis(
host=settings.config.redis_host,
port=settings.config.redis_port,
password=settings.config.redis_password or None,
decode_responses=decode_responses,
host=HOST,
port=PORT,
password=PASSWORD,
decode_responses=True,
)
c.ping()
return c
@@ -37,9 +42,9 @@ def get_redis() -> Redis:
@conn_retry("AsyncRedis", "Acquiring connection")
async def connect_async() -> AsyncRedis:
c = AsyncRedis(
host=settings.config.redis_host,
port=settings.config.redis_port,
password=settings.config.redis_password or None,
host=HOST,
port=PORT,
password=PASSWORD,
decode_responses=True,
)
await c.ping()


@@ -15,9 +15,9 @@ from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
from backend.data.db import prisma
from backend.data.model import User, UserIntegrations, UserMetadata
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.server.v2.store.exceptions import DatabaseError
from backend.util.cache import cached
from backend.util.encryption import JSONCryptor
from backend.util.exceptions import DatabaseError
from backend.util.json import SafeJson
from backend.util.settings import Settings
@@ -354,6 +354,36 @@ async def set_user_email_verification(user_id: str, verified: bool) -> None:
) from e
async def disable_all_user_notifications(user_id: str) -> None:
"""Disable all notification preferences for a user.
Used when user's email bounces/is inactive to prevent any future notifications.
"""
try:
await PrismaUser.prisma().update(
where={"id": user_id},
data={
"notifyOnAgentRun": False,
"notifyOnZeroBalance": False,
"notifyOnLowBalance": False,
"notifyOnBlockExecutionFailed": False,
"notifyOnContinuousAgentError": False,
"notifyOnDailySummary": False,
"notifyOnWeeklySummary": False,
"notifyOnMonthlySummary": False,
"notifyOnAgentApproved": False,
"notifyOnAgentRejected": False,
},
)
# Invalidate cache for this user
get_user_by_id.cache_delete(user_id)
logger.info(f"Disabled all notification preferences for user {user_id}")
except Exception as e:
raise DatabaseError(
f"Failed to disable notifications for user {user_id}: {e}"
) from e
async def get_user_email_verification(user_id: str) -> bool:
"""Get the email verification status for a user."""
try:

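The update payload in `disable_all_user_notifications` sets every per-channel flag to `False`. Generating it from a single flag list keeps the set of channels in one place; a sketch (the flag names mirror the diff):

```python
NOTIFICATION_FLAGS = [
    "notifyOnAgentRun", "notifyOnZeroBalance", "notifyOnLowBalance",
    "notifyOnBlockExecutionFailed", "notifyOnContinuousAgentError",
    "notifyOnDailySummary", "notifyOnWeeklySummary", "notifyOnMonthlySummary",
    "notifyOnAgentApproved", "notifyOnAgentRejected",
]

def disable_all_payload() -> dict[str, bool]:
    """Prisma-style update data disabling every notification preference."""
    return {flag: False for flag in NOTIFICATION_FLAGS}
```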

@@ -22,15 +22,13 @@ logger = logging.getLogger(__name__)
@pytest.fixture
def redis_client():
"""Get Redis client for testing using same config as backend."""
from backend.util.settings import Settings
settings = Settings()
from backend.data.redis_client import HOST, PASSWORD, PORT
# Use same config as backend but without decode_responses since ClusterLock needs raw bytes
client = redis.Redis(
host=settings.config.redis_host,
port=settings.config.redis_port,
password=settings.config.redis_password or None,
host=HOST,
port=PORT,
password=PASSWORD,
decode_responses=False, # ClusterLock needs raw bytes for ownership verification
)


@@ -9,6 +9,7 @@ from backend.data.execution import (
get_execution_kv_data,
get_graph_execution_meta,
get_graph_executions,
get_graph_executions_count,
get_latest_node_execution,
get_node_execution,
get_node_executions,
@@ -28,11 +29,13 @@ from backend.data.graph import (
get_node,
)
from backend.data.notifications import (
clear_all_user_notification_batches,
create_or_add_to_user_notification_batch,
empty_user_notification_batch,
get_all_batches_by_type,
get_user_notification_batch,
get_user_notification_oldest_message_in_batch,
remove_notifications_from_batch,
)
from backend.data.user import (
get_active_user_ids_in_timerange,
@@ -54,7 +57,6 @@ from backend.util.service import (
from backend.util.settings import Config
config = Config()
_user_credit_model = get_user_credit_model()
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
@@ -63,15 +65,16 @@ R = TypeVar("R")
async def _spend_credits(
user_id: str, cost: int, metadata: UsageTransactionMetadata
) -> int:
return await _user_credit_model.spend_credits(user_id, cost, metadata)
user_credit_model = await get_user_credit_model(user_id)
return await user_credit_model.spend_credits(user_id, cost, metadata)
async def _get_credits(user_id: str) -> int:
return await _user_credit_model.get_credits(user_id)
user_credit_model = await get_user_credit_model(user_id)
return await user_credit_model.get_credits(user_id)
class DatabaseManager(AppService):
def run_service(self) -> None:
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
self.run_and_wait(db.connect())
@@ -111,6 +114,7 @@ class DatabaseManager(AppService):
# Executions
get_graph_executions = _(get_graph_executions)
get_graph_executions_count = _(get_graph_executions_count)
get_graph_execution_meta = _(get_graph_execution_meta)
create_graph_execution = _(create_graph_execution)
get_node_execution = _(get_node_execution)
@@ -147,10 +151,12 @@ class DatabaseManager(AppService):
get_user_notification_preference = _(get_user_notification_preference)
# Notifications - async
clear_all_user_notification_batches = _(clear_all_user_notification_batches)
create_or_add_to_user_notification_batch = _(
create_or_add_to_user_notification_batch
)
empty_user_notification_batch = _(empty_user_notification_batch)
remove_notifications_from_batch = _(remove_notifications_from_batch)
get_all_batches_by_type = _(get_all_batches_by_type)
get_user_notification_batch = _(get_user_notification_batch)
get_user_notification_oldest_message_in_batch = _(
@@ -179,6 +185,7 @@ class DatabaseManagerClient(AppServiceClient):
# Executions
get_graph_executions = _(d.get_graph_executions)
get_graph_executions_count = _(d.get_graph_executions_count)
get_graph_execution_meta = _(d.get_graph_execution_meta)
get_node_executions = _(d.get_node_executions)
update_node_execution_status = _(d.update_node_execution_status)
@@ -241,10 +248,12 @@ class DatabaseManagerAsyncClient(AppServiceClient):
get_user_notification_preference = d.get_user_notification_preference
# Notifications
clear_all_user_notification_batches = d.clear_all_user_notification_batches
create_or_add_to_user_notification_batch = (
d.create_or_add_to_user_notification_batch
)
empty_user_notification_batch = d.empty_user_notification_batch
remove_notifications_from_batch = d.remove_notifications_from_batch
get_all_batches_by_type = d.get_all_batches_by_type
get_user_notification_batch = d.get_user_notification_batch
get_user_notification_oldest_message_in_batch = (

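The diff replaces the module-level `_user_credit_model = get_user_credit_model()` singleton with a per-user async lookup resolved on every call. A minimal sketch of that shape (the `_CreditModel` class and its multiplier logic are illustrative, not the real billing code):

```python
import asyncio

class _CreditModel:
    """Hypothetical stand-in for a user-specific credit model."""

    def __init__(self, multiplier: int) -> None:
        self.multiplier = multiplier

    async def spend_credits(self, user_id: str, cost: int, metadata: dict) -> int:
        return cost * self.multiplier

async def get_user_credit_model(user_id: str) -> _CreditModel:
    """Resolve the credit model per user instead of once at import time."""
    # e.g. certain users might be on a different billing model
    return _CreditModel(multiplier=0 if user_id.startswith("beta-") else 1)

async def spend(user_id: str, cost: int) -> int:
    model = await get_user_credit_model(user_id)  # resolved per call, not cached globally
    return await model.spend_credits(user_id, cost, {})
```

Resolving per call lets the model vary by user, at the cost of one extra lookup per transaction.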

@@ -7,8 +7,10 @@ import uuid
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
import sentry_sdk
from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from prometheus_client import Gauge, start_http_server
@@ -84,7 +86,11 @@ from backend.util.file import clean_exec_files
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.metrics import DiscordChannel
from backend.util.process import AppProcess, set_service_name
from backend.util.retry import continuous_retry, func_retry
from backend.util.retry import (
continuous_retry,
func_retry,
send_rate_limited_discord_alert,
)
from backend.util.settings import Settings
from .cluster_lock import ClusterLock
@@ -184,6 +190,7 @@ async def execute_node(
_input_data.inputs = input_data
if nodes_input_masks:
_input_data.nodes_input_masks = nodes_input_masks
_input_data.user_id = user_id
input_data = _input_data.model_dump()
data.inputs = input_data
@@ -218,14 +225,37 @@ async def execute_node(
extra_exec_kwargs[field_name] = credentials
output_size = 0
# Sentry tagging to attribute block errors to users; isolation scopes don't
# propagate here, so save and restore the scope manually
scope = sentry_sdk.get_current_scope()
# Save the current user and tags so they can be restored after execution
original_user = scope._user
original_tags = dict(scope._tags) if scope._tags else {}
# Set user ID for error tracking
scope.set_user({"id": user_id})
scope.set_tag("graph_id", graph_id)
scope.set_tag("node_id", node_id)
scope.set_tag("block_name", node_block.name)
scope.set_tag("block_id", node_block.id)
for k, v in (data.user_context or UserContext(timezone="UTC")).model_dump().items():
scope.set_tag(f"user_context.{k}", v)
try:
async for output_name, output_data in node_block.execute(
input_data, **extra_exec_kwargs
):
output_data = json.convert_pydantic_to_json(output_data)
output_data = json.to_dict(output_data)
output_size += len(json.dumps(output_data))
log_metadata.debug("Node produced output", **{output_name: output_data})
yield output_name, output_data
except Exception:
# Capture exception WITH context still set before restoring scope
sentry_sdk.capture_exception(scope=scope)
sentry_sdk.flush() # Ensure it's sent before we restore scope
# Re-raise to maintain normal error flow
raise
finally:
# Ensure credentials are released even if execution fails
if creds_lock and (await creds_lock.locked()) and (await creds_lock.owned()):
@@ -240,6 +270,10 @@ async def execute_node(
execution_stats.input_size = input_size
execution_stats.output_size = output_size
# Restore scope AFTER error has been captured
scope._user = original_user
scope._tags = original_tags
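The manual save/restore of `scope._user` and `scope._tags` above is a try/finally pattern that generalizes to a small context manager. A sketch (note the real code deliberately captures the exception while the tags are still set, which this shape preserves, since `except` runs before `finally`):

```python
from contextlib import contextmanager

@contextmanager
def preserved_attrs(obj, *attr_names):
    """Save the named attributes, yield, then restore them even on error."""
    saved = {name: getattr(obj, name) for name in attr_names}
    try:
        yield obj
    finally:
        # Restore runs whether the body succeeded or raised
        for name, value in saved.items():
            setattr(obj, name, value)
```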
async def _enqueue_next_nodes(
db_client: "DatabaseManagerAsyncClient",
@@ -564,7 +598,6 @@ class ExecutionProcessor:
await persist_output(
"error", str(stats.error) or type(stats.error).__name__
)
return status
@func_retry
@@ -979,16 +1012,31 @@ class ExecutionProcessor:
if isinstance(e, Exception)
else Exception(f"{e.__class__.__name__}: {e}")
)
if not execution_stats.error:
execution_stats.error = str(error)
known_errors = (InsufficientBalanceError, ModerationError)
if isinstance(error, known_errors):
execution_stats.error = str(error)
return ExecutionStatus.FAILED
execution_status = ExecutionStatus.FAILED
log_metadata.exception(
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
)
# Send rate-limited Discord alert for unknown/unexpected errors
send_rate_limited_discord_alert(
"graph_execution",
error,
"unknown_error",
f"🚨 **Unknown Graph Execution Error**\n"
f"User: {graph_exec.user_id}\n"
f"Graph ID: {graph_exec.graph_id}\n"
f"Execution ID: {graph_exec.graph_exec_id}\n"
f"Error Type: {type(error).__name__}\n"
f"Error: {str(error)[:200]}{'...' if len(str(error)) > 200 else ''}\n",
)
raise
finally:
@@ -1163,9 +1211,9 @@ class ExecutionProcessor:
f"❌ **Insufficient Funds Alert**\n"
f"User: {user_email or user_id}\n"
f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n"
f"Current balance: ${e.balance/100:.2f}\n"
f"Attempted cost: ${abs(e.amount)/100:.2f}\n"
f"Shortfall: ${abs(shortfall)/100:.2f}\n"
f"Current balance: ${e.balance / 100:.2f}\n"
f"Attempted cost: ${abs(e.amount) / 100:.2f}\n"
f"Shortfall: ${abs(shortfall) / 100:.2f}\n"
f"[View User Details]({base_url}/admin/spending?search={user_email})"
)
@@ -1212,9 +1260,9 @@ class ExecutionProcessor:
alert_message = (
f"⚠️ **Low Balance Alert**\n"
f"User: {user_email or user_id}\n"
f"Balance dropped below ${LOW_BALANCE_THRESHOLD/100:.2f}\n"
f"Current balance: ${current_balance/100:.2f}\n"
f"Transaction cost: ${transaction_cost/100:.2f}\n"
f"Balance dropped below ${LOW_BALANCE_THRESHOLD / 100:.2f}\n"
f"Current balance: ${current_balance / 100:.2f}\n"
f"Transaction cost: ${transaction_cost / 100:.2f}\n"
f"[View User Details]({base_url}/admin/spending?search={user_email})"
)
get_notification_manager_client().discord_system_alert(
@@ -1445,10 +1493,39 @@ class ExecutionManager(AppProcess):
return
graph_exec_id = graph_exec_entry.graph_exec_id
user_id = graph_exec_entry.user_id
graph_id = graph_exec_entry.graph_id
logger.info(
f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}"
f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}, user_id={user_id}"
)
# Check user rate limit before processing
try:
# Only check executions from the last 24 hours for performance
current_running_count = get_db_client().get_graph_executions_count(
user_id=user_id,
graph_id=graph_id,
statuses=[ExecutionStatus.RUNNING],
created_time_gte=datetime.now(timezone.utc) - timedelta(hours=24),
)
if (
current_running_count
>= settings.config.max_concurrent_graph_executions_per_user
):
logger.warning(
f"[{self.service_name}] Rate limit exceeded for user {user_id} on graph {graph_id}: "
f"{current_running_count}/{settings.config.max_concurrent_graph_executions_per_user} running executions"
)
_ack_message(reject=True, requeue=True)
return
except Exception as e:
logger.error(
f"[{self.service_name}] Failed to check rate limit for user {user_id}: {e}, proceeding with execution"
)
# If rate limit check fails, proceed to avoid blocking executions
# Check for local duplicate execution first
if graph_exec_id in self.active_graph_runs:
logger.warning(
@@ -1471,11 +1548,12 @@ class ExecutionManager(AppProcess):
logger.warning(
f"[{self.service_name}] Graph {graph_exec_id} already running on pod {current_owner}"
)
_ack_message(reject=True, requeue=False)
else:
logger.warning(
f"[{self.service_name}] Could not acquire lock for {graph_exec_id} - Redis unavailable"
)
_ack_message(reject=True, requeue=True)
_ack_message(reject=True, requeue=True)
return
self._execution_locks[graph_exec_id] = cluster_lock

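The rate-limit branch above makes a three-way decision: process the message, reject-and-requeue when the user is over the concurrent-execution limit, or proceed anyway when the limit check itself failed. That decision can be isolated into a pure helper (a sketch; the tuple encoding is hypothetical, not the real ack API):

```python
def rate_limit_decision(
    running_count: int, limit: int, check_failed: bool = False
) -> tuple[bool, bool]:
    """Decide how to handle a RUN message given the user's running-execution count.

    Returns (process, requeue): process the message now, or reject it,
    optionally requeueing for a later attempt.
    """
    if check_failed:
        # If the rate-limit lookup failed, proceed rather than blocking
        # all executions (mirrors the diff's error path).
        return True, False
    if running_count >= limit:
        return False, True  # over the limit: reject and requeue for later
    return True, False
```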

@@ -8,7 +8,7 @@ if TYPE_CHECKING:
# --8<-- [start:load_webhook_managers]
@cached(ttl_seconds=3600) # Cache webhook managers for 1 hour
@cached(ttl_seconds=3600)
def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]:
webhook_managers = {}

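The `@cached(ttl_seconds=3600)` decorator above memoizes the webhook-manager registry for an hour. A minimal TTL-cache sketch for zero-argument loaders (this is not the real `backend.util.cache` implementation):

```python
import functools
import time

def cached(ttl_seconds: float):
    """Cache a zero-argument function's result until the TTL expires."""
    def decorator(fn):
        state = {"value": None, "expires": 0.0}

        @functools.wraps(fn)
        def wrapper():
            now = time.monotonic()
            if now >= state["expires"]:
                # Expired (or never computed): refresh and extend the TTL
                state["value"] = fn()
                state["expires"] = now + ttl_seconds
            return state["value"]
        return wrapper
    return decorator
```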

@@ -25,7 +25,11 @@ from backend.data.notifications import (
get_summary_params_type,
)
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.user import generate_unsubscribe_link
from backend.data.user import (
disable_all_user_notifications,
generate_unsubscribe_link,
set_user_email_verification,
)
from backend.notifications.email import EmailSender
from backend.util.clients import get_database_manager_async_client
from backend.util.logging import TruncatedLogger
@@ -38,7 +42,7 @@ from backend.util.service import (
endpoint_to_sync,
expose,
)
from backend.util.settings import Settings
from backend.util.settings import AppEnvironment, Settings
logger = TruncatedLogger(logging.getLogger(__name__), "[NotificationManager]")
settings = Settings()
@@ -124,6 +128,12 @@ def get_routing_key(event_type: NotificationType) -> str:
def queue_notification(event: NotificationEventModel) -> NotificationResult:
"""Queue a notification - exposed method for other services to call"""
# Disable in production
if settings.config.app_env == AppEnvironment.PRODUCTION:
return NotificationResult(
success=True,
message="Queueing notifications is disabled in production",
)
try:
logger.debug(f"Received Request to queue {event=}")
@@ -151,6 +161,12 @@ def queue_notification(event: NotificationEventModel) -> NotificationResult:
async def queue_notification_async(event: NotificationEventModel) -> NotificationResult:
"""Queue a notification - exposed method for other services to call"""
# Disable in production
if settings.config.app_env == AppEnvironment.PRODUCTION:
return NotificationResult(
success=True,
message="Queueing notifications is disabled in production",
)
try:
logger.debug(f"Received Request to queue {event=}")
@@ -213,6 +229,9 @@ class NotificationManager(AppService):
@expose
async def queue_weekly_summary(self):
# disable in prod
if settings.config.app_env == AppEnvironment.PRODUCTION:
return
# Use the existing event loop instead of creating a new one with asyncio.run()
asyncio.create_task(self._queue_weekly_summary())
@@ -226,7 +245,9 @@ class NotificationManager(AppService):
logger.info(
f"Querying for active users between {start_time} and {current_time}"
)
users = await get_database_manager_async_client().get_active_user_ids_in_timerange(
users = await get_database_manager_async_client(
should_retry=False
).get_active_user_ids_in_timerange(
end_time=current_time.isoformat(),
start_time=start_time.isoformat(),
)
@@ -253,6 +274,9 @@ class NotificationManager(AppService):
async def process_existing_batches(
self, notification_types: list[NotificationType]
):
# disable in prod
if settings.config.app_env == AppEnvironment.PRODUCTION:
return
# Use the existing event loop instead of creating a new process
asyncio.create_task(self._process_existing_batches(notification_types))
@@ -266,15 +290,15 @@ class NotificationManager(AppService):
for notification_type in notification_types:
# Get all batches for this notification type
batches = (
await get_database_manager_async_client().get_all_batches_by_type(
notification_type
)
)
batches = await get_database_manager_async_client(
should_retry=False
).get_all_batches_by_type(notification_type)
for batch in batches:
# Check if batch has aged out
oldest_message = await get_database_manager_async_client().get_user_notification_oldest_message_in_batch(
oldest_message = await get_database_manager_async_client(
should_retry=False
).get_user_notification_oldest_message_in_batch(
batch.user_id, notification_type
)
@@ -289,9 +313,9 @@ class NotificationManager(AppService):
# If batch has aged out, process it
if oldest_message.created_at + max_delay < current_time:
recipient_email = await get_database_manager_async_client().get_user_email_by_id(
batch.user_id
)
recipient_email = await get_database_manager_async_client(
should_retry=False
).get_user_email_by_id(batch.user_id)
if not recipient_email:
logger.error(
@@ -308,21 +332,25 @@ class NotificationManager(AppService):
f"User {batch.user_id} does not want to receive {notification_type} notifications"
)
# Clear the batch
await get_database_manager_async_client().empty_user_notification_batch(
await get_database_manager_async_client(
should_retry=False
).empty_user_notification_batch(
batch.user_id, notification_type
)
continue
batch_data = await get_database_manager_async_client().get_user_notification_batch(
batch.user_id, notification_type
)
batch_data = await get_database_manager_async_client(
should_retry=False
).get_user_notification_batch(batch.user_id, notification_type)
if not batch_data or not batch_data.notifications:
logger.error(
f"Batch data not found for user {batch.user_id}"
)
# Clear the batch
await get_database_manager_async_client().empty_user_notification_batch(
await get_database_manager_async_client(
should_retry=False
).empty_user_notification_batch(
batch.user_id, notification_type
)
continue
@@ -358,7 +386,9 @@ class NotificationManager(AppService):
)
# Clear the batch
await get_database_manager_async_client().empty_user_notification_batch(
await get_database_manager_async_client(
should_retry=False
).empty_user_notification_batch(
batch.user_id, notification_type
)
@@ -413,15 +443,13 @@ class NotificationManager(AppService):
self, user_id: str, event_type: NotificationType
) -> bool:
"""Check if a user wants to receive a notification based on their preferences and email verification status"""
validated_email = (
await get_database_manager_async_client().get_user_email_verification(
user_id
)
)
validated_email = await get_database_manager_async_client(
should_retry=False
).get_user_email_verification(user_id)
preference = (
await get_database_manager_async_client().get_user_notification_preference(
user_id
)
await get_database_manager_async_client(
should_retry=False
).get_user_notification_preference(user_id)
).preferences.get(event_type, True)
# Only email this user if both the address is verified and the preference allows it
return validated_email and preference
@@ -437,7 +465,9 @@ class NotificationManager(AppService):
try:
# Get summary data from the database
summary_data = await get_database_manager_async_client().get_user_execution_summary_data(
summary_data = await get_database_manager_async_client(
should_retry=False
).get_user_execution_summary_data(
user_id=user_id,
start_time=params.start_date,
end_time=params.end_date,
@@ -524,13 +554,13 @@ class NotificationManager(AppService):
self, user_id: str, event_type: NotificationType, event: NotificationEventModel
) -> bool:
await get_database_manager_async_client().create_or_add_to_user_notification_batch(
user_id, event_type, event
)
await get_database_manager_async_client(
should_retry=False
).create_or_add_to_user_notification_batch(user_id, event_type, event)
oldest_message = await get_database_manager_async_client().get_user_notification_oldest_message_in_batch(
user_id, event_type
)
oldest_message = await get_database_manager_async_client(
should_retry=False
).get_user_notification_oldest_message_in_batch(user_id, event_type)
if not oldest_message:
logger.error(
f"Batch for user {user_id} and type {event_type} has no oldest message whichshould never happen!!!!!!!!!!!!!!!!"
@@ -580,11 +610,9 @@ class NotificationManager(AppService):
return False
logger.debug(f"Processing immediate notification: {event}")
recipient_email = (
await get_database_manager_async_client().get_user_email_by_id(
event.user_id
)
)
recipient_email = await get_database_manager_async_client(
should_retry=False
).get_user_email_by_id(event.user_id)
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False
@@ -619,11 +647,9 @@ class NotificationManager(AppService):
return False
logger.info(f"Processing batch notification: {event}")
recipient_email = (
await get_database_manager_async_client().get_user_email_by_id(
event.user_id
)
)
recipient_email = await get_database_manager_async_client(
should_retry=False
).get_user_email_by_id(event.user_id)
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False
@@ -642,11 +668,9 @@ class NotificationManager(AppService):
if not should_send:
logger.info("Batch not old enough to send")
return False
batch = (
await get_database_manager_async_client().get_user_notification_batch(
event.user_id, event.type
)
)
batch = await get_database_manager_async_client(
should_retry=False
).get_user_notification_batch(event.user_id, event.type)
if not batch or not batch.notifications:
logger.error(f"Batch not found for user {event.user_id}")
return False
@@ -657,6 +681,7 @@ class NotificationManager(AppService):
get_notif_data_type(db_event.type)
].model_validate(
{
"id": db_event.id, # Include ID from database
"user_id": event.user_id,
"type": db_event.type,
"data": db_event.data,
@@ -679,6 +704,9 @@ class NotificationManager(AppService):
chunk_sent = False
for attempt_size in [chunk_size, 50, 25, 10, 5, 1]:
chunk = batch_messages[i : i + attempt_size]
chunk_ids = [
msg.id for msg in chunk if msg.id
] # Extract IDs for removal
try:
# Try to render the email to check its size
@@ -705,6 +733,23 @@ class NotificationManager(AppService):
user_unsub_link=unsub_link,
)
# Remove successfully sent notifications immediately
if chunk_ids:
try:
await get_database_manager_async_client(
should_retry=False
).remove_notifications_from_batch(
event.user_id, event.type, chunk_ids
)
logger.info(
f"Removed {len(chunk_ids)} sent notifications from batch"
)
except Exception as e:
logger.error(
f"Failed to remove sent notifications: {e}"
)
# Continue anyway - better to risk duplicates than lose emails
# Track successful sends
successfully_sent_count += len(chunk)
@@ -722,13 +767,137 @@ class NotificationManager(AppService):
i += len(chunk)
chunk_sent = True
break
else:
# Message is too large even after size reduction
if attempt_size == 1:
logger.error(
f"Failed to send notification at index {i}: "
f"Single notification exceeds email size limit "
f"({len(test_message):,} chars > {MAX_EMAIL_SIZE:,} chars). "
f"Removing permanently from batch - will not retry."
)
# Remove the oversized notification permanently - it will NEVER fit
if chunk_ids:
try:
await get_database_manager_async_client(
should_retry=False
).remove_notifications_from_batch(
event.user_id, event.type, chunk_ids
)
logger.info(
f"Removed oversized notification {chunk_ids[0]} from batch permanently"
)
except Exception as e:
logger.error(
f"Failed to remove oversized notification: {e}"
)
failed_indices.append(i)
i += 1
chunk_sent = True
break
# Try smaller chunk size
continue
except Exception as e:
# Check if it's a Postmark API error
if attempt_size == 1:
# Even single notification is too large
logger.error(
f"Single notification too large to send: {e}. "
f"Skipping notification at index {i}"
)
# Single notification failed - determine the actual cause
error_message = str(e).lower()
error_type = type(e).__name__
# Check for HTTP 406 - Inactive recipient (common in Postmark errors)
if "406" in error_message or "inactive" in error_message:
logger.warning(
f"Failed to send notification at index {i}: "
f"Recipient marked as inactive by Postmark. "
f"Error: {e}. Disabling ALL notifications for this user."
)
# 1. Mark email as unverified
try:
await set_user_email_verification(
event.user_id, False
)
logger.info(
f"Set email verification to false for user {event.user_id}"
)
except Exception as deactivation_error:
logger.error(
f"Failed to deactivate email for user {event.user_id}: "
f"{deactivation_error}"
)
# 2. Disable all notification preferences
try:
await disable_all_user_notifications(event.user_id)
logger.info(
f"Disabled all notification preferences for user {event.user_id}"
)
except Exception as disable_error:
logger.error(
f"Failed to disable notification preferences: {disable_error}"
)
# 3. Clear ALL notification batches for this user
try:
await get_database_manager_async_client(
should_retry=False
).clear_all_user_notification_batches(event.user_id)
logger.info(
f"Cleared ALL notification batches for user {event.user_id}"
)
except Exception as remove_error:
logger.error(
f"Failed to clear batches for inactive recipient: {remove_error}"
)
# Stop processing - we've nuked everything for this user
return True
# Check for HTTP 422 - Malformed data
elif (
"422" in error_message
or "unprocessable" in error_message
):
logger.error(
f"Failed to send notification at index {i}: "
f"Malformed notification data rejected by Postmark. "
f"Error: {e}. Removing from batch permanently."
)
# Remove from batch - 422 means bad data that won't fix itself
if chunk_ids:
try:
await get_database_manager_async_client(
should_retry=False
).remove_notifications_from_batch(
event.user_id, event.type, chunk_ids
)
logger.info(
"Removed malformed notification from batch permanently"
)
except Exception as remove_error:
logger.error(
f"Failed to remove malformed notification: {remove_error}"
)
# Check if it's a ValueError for size limit
elif (
isinstance(e, ValueError)
and "too large" in error_message
):
logger.error(
f"Failed to send notification at index {i}: "
f"Notification size exceeds email limit. "
f"Error: {e}. Skipping this notification."
)
# Other API errors
else:
logger.error(
f"Failed to send notification at index {i}: "
f"Email API error ({error_type}): {e}. "
f"Skipping this notification."
)
failed_indices.append(i)
i += 1
chunk_sent = True
@@ -742,18 +911,20 @@ class NotificationManager(AppService):
failed_indices.append(i)
i += 1
# Only empty the batch if ALL notifications were sent successfully
if successfully_sent_count == len(batch_messages):
# Check what remains in the batch (notifications are removed as sent)
remaining_batch = await get_database_manager_async_client(
should_retry=False
).get_user_notification_batch(event.user_id, event.type)
if not remaining_batch or not remaining_batch.notifications:
logger.info(
f"Successfully sent all {successfully_sent_count} notifications, clearing batch"
)
await get_database_manager_async_client().empty_user_notification_batch(
event.user_id, event.type
f"All {successfully_sent_count} notifications sent and removed from batch"
)
else:
remaining_count = len(remaining_batch.notifications)
logger.warning(
f"Only sent {successfully_sent_count} of {len(batch_messages)} notifications. "
f"Failed indices: {failed_indices}. Batch will be retained for retry."
f"Sent {successfully_sent_count} notifications. "
f"{remaining_count} remain in batch for retry due to errors."
)
return True
except Exception as e:
@@ -771,11 +942,9 @@ class NotificationManager(AppService):
logger.info(f"Processing summary notification: {model}")
recipient_email = (
await get_database_manager_async_client().get_user_email_by_id(
event.user_id
)
)
recipient_email = await get_database_manager_async_client(
should_retry=False
).get_user_email_by_id(event.user_id)
if not recipient_email:
logger.error(f"User email not found for user {event.user_id}")
return False

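The batch-send loop above walks the attempt sizes `[chunk_size, 50, 25, 10, 5, 1]`, shrinking the chunk when a send fails and permanently skipping a single message that will never fit. That retry shape can be isolated into a helper; `send` here is a hypothetical callable that raises when a payload is too large:

```python
from typing import Callable, Sequence

def send_in_chunks(
    messages: list,
    send: Callable[[list], None],
    sizes: Sequence[int] = (100, 50, 25, 10, 5, 1),
) -> tuple[list, list]:
    """Send messages in chunks, shrinking the chunk size on failure.

    Returns (sent, skipped). A message that fails even as a chunk of one
    is skipped permanently, mirroring the batch logic above.
    """
    sent, skipped = [], []
    i = 0
    while i < len(messages):
        for size in sizes:
            chunk = messages[i : i + size]
            try:
                send(chunk)
            except Exception:
                if size == 1:
                    skipped.extend(chunk)  # a single message will never fit
                    i += 1
                    break
                continue  # retry with a smaller chunk
            sent.extend(chunk)
            i += len(chunk)
            break
    return sent, skipped
```

Because the size ladder ends at 1, the loop always makes progress: every iteration either sends a chunk or skips one message.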

@@ -0,0 +1,598 @@
"""Tests for notification error handling in NotificationManager."""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from prisma.enums import NotificationType
from backend.data.notifications import AgentRunData, NotificationEventModel
from backend.notifications.notifications import NotificationManager
class TestNotificationErrorHandling:
"""Test cases for notification error handling in NotificationManager."""
@pytest.fixture
def notification_manager(self):
"""Create a NotificationManager instance for testing."""
with patch("backend.notifications.notifications.AppService.__init__"):
manager = NotificationManager()
manager.email_sender = MagicMock()
# Mock the _get_template method used by _process_batch
template_mock = Mock()
template_mock.base_template = "base"
template_mock.subject_template = "subject"
template_mock.body_template = "body"
manager.email_sender._get_template = Mock(return_value=template_mock)
# Mock the formatter
manager.email_sender.formatter = Mock()
manager.email_sender.formatter.format_email = Mock(
return_value=("subject", "body content")
)
manager.email_sender.formatter.env = Mock()
manager.email_sender.formatter.env.globals = {
"base_url": "http://example.com"
}
return manager
@pytest.fixture
def sample_batch_event(self):
"""Create a sample batch event for testing."""
return NotificationEventModel(
type=NotificationType.AGENT_RUN,
user_id="user_1",
created_at=datetime.now(timezone.utc),
data=AgentRunData(
agent_name="Test Agent",
credits_used=10.0,
execution_time=5.0,
node_count=3,
graph_id="graph_1",
outputs=[],
),
)
@pytest.fixture
def sample_batch_notifications(self):
"""Create sample batch notifications for testing."""
notifications = []
for i in range(3):
notification = Mock()
notification.type = NotificationType.AGENT_RUN
notification.data = {
"agent_name": f"Test Agent {i}",
"credits_used": 10.0 * (i + 1),
"execution_time": 5.0 * (i + 1),
"node_count": 3 + i,
"graph_id": f"graph_{i}",
"outputs": [],
}
notification.created_at = datetime.now(timezone.utc)
notifications.append(notification)
return notifications
@pytest.mark.asyncio
async def test_406_stops_all_processing_for_user(
self, notification_manager, sample_batch_event
):
"""Test that 406 inactive recipient error stops ALL processing for that user."""
with patch("backend.notifications.notifications.logger"), patch(
"backend.notifications.notifications.set_user_email_verification",
new_callable=AsyncMock,
) as mock_set_verification, patch(
"backend.notifications.notifications.disable_all_user_notifications",
new_callable=AsyncMock,
) as mock_disable_all, patch(
"backend.notifications.notifications.get_database_manager_async_client"
) as mock_db_client, patch(
"backend.notifications.notifications.generate_unsubscribe_link"
) as mock_unsub_link:
# Create batch of 5 notifications
notifications = []
for i in range(5):
notification = Mock()
notification.id = f"notif_{i}"
notification.type = NotificationType.AGENT_RUN
notification.data = {
"agent_name": f"Test Agent {i}",
"credits_used": 10.0 * (i + 1),
"execution_time": 5.0 * (i + 1),
"node_count": 3 + i,
"graph_id": f"graph_{i}",
"outputs": [],
}
notification.created_at = datetime.now(timezone.utc)
notifications.append(notification)
# Setup mocks
mock_db = mock_db_client.return_value
mock_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
mock_db.get_user_notification_batch = AsyncMock(
return_value=Mock(notifications=notifications)
)
mock_db.clear_all_user_notification_batches = AsyncMock()
mock_db.remove_notifications_from_batch = AsyncMock()
mock_unsub_link.return_value = "http://example.com/unsub"
# Mock internal methods
notification_manager._should_email_user_based_on_preference = AsyncMock(
return_value=True
)
notification_manager._should_batch = AsyncMock(return_value=True)
notification_manager._parse_message = Mock(return_value=sample_batch_event)
# Track calls
call_count = [0]
def send_side_effect(*args, **kwargs):
data = kwargs.get("data", [])
if isinstance(data, list) and len(data) == 1:
current_call = call_count[0]
call_count[0] += 1
# First two succeed, third hits 406
if current_call < 2:
return None
else:
raise Exception("Recipient marked as inactive (406)")
# Force single processing
raise Exception("Force single processing")
notification_manager.email_sender.send_templated.side_effect = (
send_side_effect
)
# Act
result = await notification_manager._process_batch(
sample_batch_event.model_dump_json()
)
# Assert
assert result is True
# Only 3 calls should have been made (2 successful, 1 failed with 406)
assert call_count[0] == 3
# User should be deactivated
mock_set_verification.assert_called_once_with("user_1", False)
mock_disable_all.assert_called_once_with("user_1")
mock_db.clear_all_user_notification_batches.assert_called_once_with(
"user_1"
)
# No further processing should occur after 406
@pytest.mark.asyncio
async def test_422_permanently_removes_malformed_notification(
self, notification_manager, sample_batch_event
):
"""Test that 422 error permanently removes the malformed notification from batch and continues with others."""
with patch("backend.notifications.notifications.logger") as mock_logger, patch(
"backend.notifications.notifications.get_database_manager_async_client"
) as mock_db_client, patch(
"backend.notifications.notifications.generate_unsubscribe_link"
) as mock_unsub_link:
# Create batch of 5 notifications
notifications = []
for i in range(5):
notification = Mock()
notification.id = f"notif_{i}"
notification.type = NotificationType.AGENT_RUN
notification.data = {
"agent_name": f"Test Agent {i}",
"credits_used": 10.0 * (i + 1),
"execution_time": 5.0 * (i + 1),
"node_count": 3 + i,
"graph_id": f"graph_{i}",
"outputs": [],
}
notification.created_at = datetime.now(timezone.utc)
notifications.append(notification)
# Setup mocks
mock_db = mock_db_client.return_value
mock_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
mock_db.get_user_notification_batch = AsyncMock(
side_effect=[
Mock(notifications=notifications),
Mock(notifications=[]), # Empty after processing
]
)
mock_db.remove_notifications_from_batch = AsyncMock()
mock_unsub_link.return_value = "http://example.com/unsub"
# Mock internal methods
notification_manager._should_email_user_based_on_preference = AsyncMock(
return_value=True
)
notification_manager._should_batch = AsyncMock(return_value=True)
notification_manager._parse_message = Mock(return_value=sample_batch_event)
# Track calls
call_count = [0]
successful_indices = []
removed_notification_ids = []
# Capture what gets removed
def remove_side_effect(user_id, notif_type, notif_ids):
removed_notification_ids.extend(notif_ids)
return None
mock_db.remove_notifications_from_batch.side_effect = remove_side_effect
def send_side_effect(*args, **kwargs):
data = kwargs.get("data", [])
if isinstance(data, list) and len(data) == 1:
current_call = call_count[0]
call_count[0] += 1
# Index 2 has malformed data (422)
if current_call == 2:
raise Exception(
"Unprocessable entity (422): Malformed email data"
)
else:
successful_indices.append(current_call)
return None
# Force single processing
raise Exception("Force single processing")
notification_manager.email_sender.send_templated.side_effect = (
send_side_effect
)
# Act
result = await notification_manager._process_batch(
sample_batch_event.model_dump_json()
)
# Assert
assert result is True
assert call_count[0] == 5 # All 5 attempted
assert len(successful_indices) == 4 # 4 succeeded (all except index 2)
assert 2 not in successful_indices # Index 2 failed
# Verify 422 error was logged
error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
assert any(
"422" in call or "malformed" in call.lower() for call in error_calls
)
# Verify all notifications were removed (4 successful + 1 malformed)
assert mock_db.remove_notifications_from_batch.call_count == 5
assert (
"notif_2" in removed_notification_ids
) # Malformed one was removed permanently
@pytest.mark.asyncio
async def test_oversized_notification_permanently_removed(
self, notification_manager, sample_batch_event
):
"""Test that oversized notifications are permanently removed from batch but others continue."""
with patch("backend.notifications.notifications.logger") as mock_logger, patch(
"backend.notifications.notifications.get_database_manager_async_client"
) as mock_db_client, patch(
"backend.notifications.notifications.generate_unsubscribe_link"
) as mock_unsub_link:
# Create batch of 5 notifications
notifications = []
for i in range(5):
notification = Mock()
notification.id = f"notif_{i}"
notification.type = NotificationType.AGENT_RUN
notification.data = {
"agent_name": f"Test Agent {i}",
"credits_used": 10.0 * (i + 1),
"execution_time": 5.0 * (i + 1),
"node_count": 3 + i,
"graph_id": f"graph_{i}",
"outputs": [],
}
notification.created_at = datetime.now(timezone.utc)
notifications.append(notification)
# Setup mocks
mock_db = mock_db_client.return_value
mock_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
mock_db.get_user_notification_batch = AsyncMock(
side_effect=[
Mock(notifications=notifications),
Mock(notifications=[]), # Empty after processing
]
)
mock_db.remove_notifications_from_batch = AsyncMock()
mock_unsub_link.return_value = "http://example.com/unsub"
# Mock internal methods
notification_manager._should_email_user_based_on_preference = AsyncMock(
return_value=True
)
notification_manager._should_batch = AsyncMock(return_value=True)
notification_manager._parse_message = Mock(return_value=sample_batch_event)
# Override formatter to simulate oversized on index 3
# original_format = notification_manager.email_sender.formatter.format_email
def format_side_effect(*args, **kwargs):
# Check if we're formatting index 3
data = kwargs.get("data", {}).get("notifications", [])
if data and len(data) == 1:
# Check notification content to identify index 3
if any(
"Test Agent 3" in str(n.data)
for n in data
if hasattr(n, "data")
):
# Return oversized message for index 3
return ("subject", "x" * 5_000_000) # Over 4.5MB limit
return ("subject", "normal sized content")
notification_manager.email_sender.formatter.format_email = Mock(
side_effect=format_side_effect
)
# Track calls
successful_indices = []
def send_side_effect(*args, **kwargs):
data = kwargs.get("data", [])
if isinstance(data, list) and len(data) == 1:
# Track which notification was sent based on content
for i, notif in enumerate(notifications):
if any(
f"Test Agent {i}" in str(n.data)
for n in data
if hasattr(n, "data")
):
successful_indices.append(i)
return None
return None
# Force single processing
raise Exception("Force single processing")
notification_manager.email_sender.send_templated.side_effect = (
send_side_effect
)
# Act
result = await notification_manager._process_batch(
sample_batch_event.model_dump_json()
)
# Assert
assert result is True
assert (
len(successful_indices) == 4
) # Only 4 sent (index 3 skipped due to size)
assert 3 not in successful_indices # Index 3 was not sent
# Verify oversized error was logged
error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
assert any(
"exceeds email size limit" in call or "oversized" in call.lower()
for call in error_calls
)
@pytest.mark.asyncio
async def test_generic_api_error_keeps_notification_for_retry(
self, notification_manager, sample_batch_event
):
"""Test that generic API errors keep notifications in batch for retry while others continue."""
with patch("backend.notifications.notifications.logger") as mock_logger, patch(
"backend.notifications.notifications.get_database_manager_async_client"
) as mock_db_client, patch(
"backend.notifications.notifications.generate_unsubscribe_link"
) as mock_unsub_link:
# Create batch of 5 notifications
notifications = []
for i in range(5):
notification = Mock()
notification.id = f"notif_{i}"
notification.type = NotificationType.AGENT_RUN
notification.data = {
"agent_name": f"Test Agent {i}",
"credits_used": 10.0 * (i + 1),
"execution_time": 5.0 * (i + 1),
"node_count": 3 + i,
"graph_id": f"graph_{i}",
"outputs": [],
}
notification.created_at = datetime.now(timezone.utc)
notifications.append(notification)
# Notification that failed with generic error
failed_notifications = [notifications[1]] # Only index 1 remains for retry
# Setup mocks
mock_db = mock_db_client.return_value
mock_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
mock_db.get_user_notification_batch = AsyncMock(
side_effect=[
Mock(notifications=notifications),
Mock(
notifications=failed_notifications
), # Failed ones remain for retry
]
)
mock_db.remove_notifications_from_batch = AsyncMock()
mock_unsub_link.return_value = "http://example.com/unsub"
# Mock internal methods
notification_manager._should_email_user_based_on_preference = AsyncMock(
return_value=True
)
notification_manager._should_batch = AsyncMock(return_value=True)
notification_manager._parse_message = Mock(return_value=sample_batch_event)
# Track calls
successful_indices = []
failed_indices = []
removed_notification_ids = []
# Capture what gets removed
def remove_side_effect(user_id, notif_type, notif_ids):
removed_notification_ids.extend(notif_ids)
return None
mock_db.remove_notifications_from_batch.side_effect = remove_side_effect
def send_side_effect(*args, **kwargs):
data = kwargs.get("data", [])
if isinstance(data, list) and len(data) == 1:
# Track which notification based on content
for i, notif in enumerate(notifications):
if any(
f"Test Agent {i}" in str(n.data)
for n in data
if hasattr(n, "data")
):
# Index 1 has generic API error
if i == 1:
failed_indices.append(i)
raise Exception("Network timeout - temporary failure")
else:
successful_indices.append(i)
return None
return None
# Force single processing
raise Exception("Force single processing")
notification_manager.email_sender.send_templated.side_effect = (
send_side_effect
)
# Act
result = await notification_manager._process_batch(
sample_batch_event.model_dump_json()
)
# Assert
assert result is True
assert len(successful_indices) == 4 # 4 succeeded (0, 2, 3, 4)
assert len(failed_indices) == 1 # 1 failed
assert 1 in failed_indices # Index 1 failed
# Verify generic error was logged
error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
assert any(
"api error" in call.lower() or "skipping" in call.lower()
for call in error_calls
)
# Only successful ones should be removed from batch (failed one stays for retry)
assert mock_db.remove_notifications_from_batch.call_count == 4
assert (
"notif_1" not in removed_notification_ids
) # Failed one NOT removed (stays for retry)
assert "notif_0" in removed_notification_ids # Successful one removed
assert "notif_2" in removed_notification_ids # Successful one removed
assert "notif_3" in removed_notification_ids # Successful one removed
assert "notif_4" in removed_notification_ids # Successful one removed
@pytest.mark.asyncio
async def test_batch_all_notifications_sent_successfully(
self, notification_manager, sample_batch_event
):
"""Test successful batch processing where all notifications are sent without errors."""
with patch("backend.notifications.notifications.logger") as mock_logger, patch(
"backend.notifications.notifications.get_database_manager_async_client"
) as mock_db_client, patch(
"backend.notifications.notifications.generate_unsubscribe_link"
) as mock_unsub_link:
# Create batch of 5 notifications
notifications = []
for i in range(5):
notification = Mock()
notification.id = f"notif_{i}"
notification.type = NotificationType.AGENT_RUN
notification.data = {
"agent_name": f"Test Agent {i}",
"credits_used": 10.0 * (i + 1),
"execution_time": 5.0 * (i + 1),
"node_count": 3 + i,
"graph_id": f"graph_{i}",
"outputs": [],
}
notification.created_at = datetime.now(timezone.utc)
notifications.append(notification)
# Setup mocks
mock_db = mock_db_client.return_value
mock_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
mock_db.get_user_notification_batch = AsyncMock(
side_effect=[
Mock(notifications=notifications),
Mock(notifications=[]), # Empty after all sent successfully
]
)
mock_db.remove_notifications_from_batch = AsyncMock()
mock_unsub_link.return_value = "http://example.com/unsub"
# Mock internal methods
notification_manager._should_email_user_based_on_preference = AsyncMock(
return_value=True
)
notification_manager._should_batch = AsyncMock(return_value=True)
notification_manager._parse_message = Mock(return_value=sample_batch_event)
# Track successful sends
successful_indices = []
removed_notification_ids = []
# Capture what gets removed
def remove_side_effect(user_id, notif_type, notif_ids):
removed_notification_ids.extend(notif_ids)
return None
mock_db.remove_notifications_from_batch.side_effect = remove_side_effect
def send_side_effect(*args, **kwargs):
data = kwargs.get("data", [])
if isinstance(data, list) and len(data) == 1:
# Track which notification was sent
for i, notif in enumerate(notifications):
if any(
f"Test Agent {i}" in str(n.data)
for n in data
if hasattr(n, "data")
):
successful_indices.append(i)
return None
return None # Success
# Force single processing
raise Exception("Force single processing")
notification_manager.email_sender.send_templated.side_effect = (
send_side_effect
)
# Act
result = await notification_manager._process_batch(
sample_batch_event.model_dump_json()
)
# Assert
assert result is True
# All 5 notifications should be sent successfully
assert len(successful_indices) == 5
assert successful_indices == [0, 1, 2, 3, 4]
# All notifications should be removed from batch
assert mock_db.remove_notifications_from_batch.call_count == 5
assert len(removed_notification_ids) == 5
for i in range(5):
assert f"notif_{i}" in removed_notification_ids
# No errors should be logged
assert mock_logger.error.call_count == 0
# Info message about successful sends should be logged
info_calls = [call[0][0] for call in mock_logger.info.call_args_list]
assert any("sent and removed" in call.lower() for call in info_calls)
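Taken together, the scenarios above pin down a per-error policy: a 406 inactive-recipient error deactivates the user and clears all their batches, 422/malformed and oversized messages drop only the offending notification, and any other failure leaves the notification in the batch for retry. A minimal sketch of that classification, assuming (as the tests do) that errors are distinguished by message text — the real `NotificationManager` may inspect provider status codes directly:

```python
# Hypothetical sketch of the error-handling policy exercised by the tests
# above. Assumption: errors are classified by message text, mirroring the
# exception messages the tests raise; a production implementation would
# inspect the email provider's status code instead.
from enum import Enum, auto


class EmailErrorAction(Enum):
    DEACTIVATE_USER = auto()    # 406: stop all processing, disable notifications
    DROP_NOTIFICATION = auto()  # 422 / oversized: remove permanently, continue
    RETRY_LATER = auto()        # transient failure: keep in batch for retry


def classify_send_error(exc: Exception) -> EmailErrorAction:
    msg = str(exc).lower()
    if "406" in msg or "inactive" in msg:
        return EmailErrorAction.DEACTIVATE_USER
    if "422" in msg or "malformed" in msg or "exceeds email size limit" in msg:
        return EmailErrorAction.DROP_NOTIFICATION
    return EmailErrorAction.RETRY_LATER
```

Under this sketch, the three exception messages raised in the tests map to the three distinct actions the assertions check for.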


@@ -1,86 +0,0 @@
"""
Shared cache configuration constants.
This module defines all page_size defaults used across the application.
By centralizing these values, we ensure that cache invalidation always
uses the same page_size as the routes that populate the cache.
CRITICAL: If you change any of these values, the tests in
test_cache_invalidation_consistency.py will fail, reminding you to
update all dependent code.
"""
# V1 API (legacy) page sizes
V1_GRAPHS_PAGE_SIZE = 250
"""Default page size for listing user graphs in v1 API."""
V1_LIBRARY_AGENTS_PAGE_SIZE = 10
"""Default page size for library agents in v1 API."""
V1_GRAPH_EXECUTIONS_PAGE_SIZE = 25
"""Default page size for graph executions in v1 API."""
# V2 Store API page sizes
V2_STORE_AGENTS_PAGE_SIZE = 20
"""Default page size for store agents listing."""
V2_STORE_CREATORS_PAGE_SIZE = 20
"""Default page size for store creators listing."""
V2_STORE_SUBMISSIONS_PAGE_SIZE = 20
"""Default page size for user submissions listing."""
V2_MY_AGENTS_PAGE_SIZE = 20
"""Default page size for user's own agents listing."""
# V2 Library API page sizes
V2_LIBRARY_AGENTS_PAGE_SIZE = 10
"""Default page size for library agents listing in v2 API."""
V2_LIBRARY_PRESETS_PAGE_SIZE = 20
"""Default page size for library presets listing."""
# Alternative page sizes (for backward compatibility or special cases)
V2_LIBRARY_PRESETS_ALT_PAGE_SIZE = 10
"""
Alternative page size for library presets.
Some clients may use this smaller page size, so cache clearing must handle both.
"""
V2_GRAPH_EXECUTIONS_ALT_PAGE_SIZE = 10
"""
Alternative page size for graph executions.
Some clients may use this smaller page size, so cache clearing must handle both.
"""
# Cache clearing configuration
MAX_PAGES_TO_CLEAR = 20
"""
Maximum number of pages to clear when invalidating paginated caches.
This prevents infinite loops while ensuring we clear most cached pages.
For users with more than 20 pages, those pages will expire naturally via TTL.
"""
def get_page_sizes_for_clearing(
primary_page_size: int, alt_page_size: int | None = None
) -> list[int]:
"""
Get all page_size values that should be cleared for a given cache.
Args:
primary_page_size: The main page_size used by the route
alt_page_size: Optional alternative page_size if multiple clients use different sizes
Returns:
List of page_size values to clear
Example:
>>> get_page_sizes_for_clearing(20)
[20]
>>> get_page_sizes_for_clearing(20, 10)
[20, 10]
"""
if alt_page_size is None:
return [primary_page_size]
return [primary_page_size, alt_page_size]


@@ -14,19 +14,49 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
@pytest.fixture
def test_user_id() -> str:
"""Test user ID fixture."""
-return "test-user-id"
+return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
@pytest.fixture
def admin_user_id() -> str:
"""Admin user ID fixture."""
-return "admin-user-id"
+return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
@pytest.fixture
def target_user_id() -> str:
"""Target user ID fixture."""
-return "target-user-id"
+return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
@pytest.fixture
async def setup_test_user(test_user_id):
"""Create test user in database before tests."""
from backend.data.user import get_or_create_user
# Create the test user in the database using JWT token format
user_data = {
"sub": test_user_id,
"email": "test@example.com",
"user_metadata": {"name": "Test User"},
}
await get_or_create_user(user_data)
return test_user_id
@pytest.fixture
async def setup_admin_user(admin_user_id):
"""Create admin user in database before tests."""
from backend.data.user import get_or_create_user
# Create the admin user in the database using JWT token format
user_data = {
"sub": admin_user_id,
"email": "test-admin@example.com",
"user_metadata": {"name": "Test Admin"},
}
await get_or_create_user(user_data)
return admin_user_id
@pytest.fixture


@@ -64,7 +64,7 @@ class LoginResponse(BaseModel):
state_token: str
-@router.get("/{provider}/login")
+@router.get("/{provider}/login", summary="Initiate OAuth flow")
async def login(
provider: Annotated[
ProviderName, Path(title="The provider to initiate an OAuth flow for")
@@ -102,7 +102,7 @@ class CredentialsMetaResponse(BaseModel):
)
-@router.post("/{provider}/callback")
+@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
async def callback(
provider: Annotated[
ProviderName, Path(title="The target provider for this OAuth exchange")


@@ -1,154 +0,0 @@
"""
Cache functions for main V1 API endpoints.
This module contains all caching decorators and helpers for the V1 API,
separated from the main routes for better organization and maintainability.
"""
from typing import Sequence
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data import user as user_db
from backend.data.block import get_blocks
from backend.util.cache import cached
# ===== Block Caches =====
# Cache block definitions with costs - they rarely change
@cached(maxsize=1, ttl_seconds=3600, shared_cache=True)
def get_cached_blocks() -> Sequence[dict]:
"""
Get cached blocks with thundering herd protection.
Uses cached decorator to prevent multiple concurrent requests
from all executing the expensive block loading operation.
"""
from backend.data.credit import get_block_cost
block_classes = get_blocks()
result = []
for block_class in block_classes.values():
block_instance = block_class()
if not block_instance.disabled:
# Get costs for this specific block class without creating another instance
costs = get_block_cost(block_instance)
result.append({**block_instance.to_dict(), "costs": costs})
return result
# ===== Graph Caches =====
# Cache user's graphs list for 15 minutes
@cached(maxsize=1000, ttl_seconds=900, shared_cache=True)
async def get_cached_graphs(
user_id: str,
page: int,
page_size: int,
):
"""Cached helper to get user's graphs."""
return await graph_db.list_graphs_paginated(
user_id=user_id,
page=page,
page_size=page_size,
)
# Cache individual graph details for 30 minutes
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
async def get_cached_graph(
graph_id: str,
version: int | None,
user_id: str,
):
"""Cached helper to get graph details."""
return await graph_db.get_graph(
graph_id=graph_id,
version=version,
user_id=user_id,
include_subgraphs=True, # needed to construct full credentials input schema
)
# Cache graph versions for 30 minutes
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
async def get_cached_graph_all_versions(
graph_id: str,
user_id: str,
) -> Sequence[graph_db.GraphModel]:
"""Cached helper to get all versions of a graph."""
return await graph_db.get_graph_all_versions(
graph_id=graph_id,
user_id=user_id,
)
# ===== Execution Caches =====
# Cache graph executions for 10 seconds.
@cached(maxsize=1000, ttl_seconds=10, shared_cache=True)
async def get_cached_graph_executions(
graph_id: str,
user_id: str,
page: int,
page_size: int,
):
"""Cached helper to get graph executions."""
return await execution_db.get_graph_executions_paginated(
graph_id=graph_id,
user_id=user_id,
page=page,
page_size=page_size,
)
# Cache all user executions for 10 seconds.
@cached(maxsize=500, ttl_seconds=10, shared_cache=True)
async def get_cached_graphs_executions(
user_id: str,
page: int,
page_size: int,
):
"""Cached helper to get all user's graph executions."""
return await execution_db.get_graph_executions_paginated(
user_id=user_id,
page=page,
page_size=page_size,
)
# Cache individual execution details for 10 seconds.
@cached(maxsize=1000, ttl_seconds=10, shared_cache=True)
async def get_cached_graph_execution(
graph_exec_id: str,
user_id: str,
):
"""Cached helper to get graph execution details."""
return await execution_db.get_graph_execution(
user_id=user_id,
execution_id=graph_exec_id,
include_node_executions=False,
)
# ===== User Preference Caches =====
# Cache user timezone for 1 hour
@cached(maxsize=1000, ttl_seconds=3600, shared_cache=True)
async def get_cached_user_timezone(user_id: str):
"""Cached helper to get user timezone."""
user = await user_db.get_user_by_id(user_id)
return {"timezone": user.timezone if user else "UTC"}
# Cache user preferences for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_user_preferences(user_id: str):
"""Cached helper to get user notification preferences."""
return await user_db.get_user_notification_preference(user_id)
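The tests in the next file drive these helpers through `cache_clear` and `cache_delete`, which the `cached` decorator exposes on each wrapped function. A toy sketch of such a decorator for synchronous functions, assuming a key-by-arguments TTL design — the real `backend.util.cache` additionally supports async functions and a shared cache, which this sketch omits:

```python
# Toy TTL cache decorator exposing cache_delete/cache_clear, modeled on
# (but much simpler than) the `cached` decorator used above. All names
# here are illustrative, not the real backend.util.cache implementation.
import time
from functools import wraps


def cached(maxsize: int = 128, ttl_seconds: float = 60.0):
    def decorator(fn):
        store = {}  # key (positional args tuple) -> (expires_at, value)

        @wraps(fn)
        def wrapper(*args):
            now = time.monotonic()
            hit = store.get(args)
            if hit is not None and hit[0] > now:
                return hit[1]  # fresh cache hit
            value = fn(*args)
            if len(store) >= maxsize:
                store.pop(next(iter(store)))  # crude eviction: oldest insert
            store[args] = (now + ttl_seconds, value)
            return value

        def cache_delete(*args) -> bool:
            # True if an entry existed, matching the asserts in the tests below
            return store.pop(args, None) is not None

        wrapper.cache_delete = cache_delete
        wrapper.cache_clear = store.clear
        return wrapper
    return decorator


calls = []

@cached(maxsize=10, ttl_seconds=60)
def get_user_timezone(user_id):
    calls.append(user_id)  # records every real "database" hit
    return {"timezone": "UTC"}
```

With this shape, a repeated call is served from `store` without re-invoking the wrapped function, and `cache_delete(user_id)` forces the next call back to the underlying source — the behavior the invalidation tests below assert on.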


@@ -1,376 +0,0 @@
"""
Tests for cache invalidation in V1 API routes.
This module tests that caches are properly invalidated when data is modified
through POST, PUT, PATCH, and DELETE operations.
"""
import uuid
from unittest.mock import AsyncMock, patch
import pytest
import backend.server.routers.cache as cache
from backend.data import graph as graph_db
@pytest.fixture
def mock_user_id():
"""Generate a mock user ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def mock_graph_id():
"""Generate a mock graph ID for testing."""
return str(uuid.uuid4())
class TestGraphCacheInvalidation:
"""Test cache invalidation for graph operations."""
@pytest.mark.asyncio
async def test_create_graph_clears_list_cache(self, mock_user_id):
"""Test that creating a graph clears the graphs list cache."""
# Setup
cache.get_cached_graphs.cache_clear()
# Pre-populate cache
with patch.object(
graph_db, "list_graphs_paginated", new_callable=AsyncMock
) as mock_list:
# Use a simple dict instead of MagicMock to make it pickleable
mock_list.return_value = {
"graphs": [],
"total_count": 0,
"page": 1,
"page_size": 250,
}
# First call should hit the database
await cache.get_cached_graphs(mock_user_id, 1, 250)
assert mock_list.call_count == 1
# Second call should use cache
await cache.get_cached_graphs(mock_user_id, 1, 250)
assert mock_list.call_count == 1 # Still 1, used cache
# Simulate cache invalidation (what happens in create_new_graph)
cache.get_cached_graphs.cache_delete(mock_user_id, 1, 250)
# Next call should hit database again
await cache.get_cached_graphs(mock_user_id, 1, 250)
assert mock_list.call_count == 2 # Incremented, cache was cleared
@pytest.mark.asyncio
async def test_delete_graph_clears_multiple_caches(
self, mock_user_id, mock_graph_id
):
"""Test that deleting a graph clears all related caches."""
# Clear all caches first
cache.get_cached_graphs.cache_clear()
cache.get_cached_graph.cache_clear()
cache.get_cached_graph_all_versions.cache_clear()
cache.get_cached_graph_executions.cache_clear()
# Setup mocks
with (
patch.object(
graph_db, "list_graphs_paginated", new_callable=AsyncMock
) as mock_list,
patch.object(graph_db, "get_graph", new_callable=AsyncMock) as mock_get,
patch.object(
graph_db, "get_graph_all_versions", new_callable=AsyncMock
) as mock_versions,
):
mock_list.return_value = {
"graphs": [],
"total_count": 0,
"page": 1,
"page_size": 250,
}
mock_get.return_value = {"id": mock_graph_id}
mock_versions.return_value = []
# Pre-populate all caches (use consistent argument style)
await cache.get_cached_graphs(mock_user_id, 1, 250)
await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
initial_calls = {
"list": mock_list.call_count,
"get": mock_get.call_count,
"versions": mock_versions.call_count,
}
# Use cached values (no additional DB calls)
await cache.get_cached_graphs(mock_user_id, 1, 250)
await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
# Verify cache was used
assert mock_list.call_count == initial_calls["list"]
assert mock_get.call_count == initial_calls["get"]
assert mock_versions.call_count == initial_calls["versions"]
# Simulate delete_graph cache invalidation
# Use positional arguments for cache_delete to match how we called the functions
result1 = cache.get_cached_graphs.cache_delete(mock_user_id, 1, 250)
result2 = cache.get_cached_graph.cache_delete(
mock_graph_id, None, mock_user_id
)
result3 = cache.get_cached_graph_all_versions.cache_delete(
mock_graph_id, mock_user_id
)
# Verify that the cache entries were actually deleted
assert result1, "Failed to delete graphs cache entry"
assert result2, "Failed to delete graph cache entry"
assert result3, "Failed to delete graph versions cache entry"
# Next calls should hit database
await cache.get_cached_graphs(mock_user_id, 1, 250)
await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
# Verify database was called again
assert mock_list.call_count == initial_calls["list"] + 1
assert mock_get.call_count == initial_calls["get"] + 1
assert mock_versions.call_count == initial_calls["versions"] + 1
@pytest.mark.asyncio
async def test_update_graph_clears_caches(self, mock_user_id, mock_graph_id):
"""Test that updating a graph clears the appropriate caches."""
# Clear caches
cache.get_cached_graph.cache_clear()
cache.get_cached_graph_all_versions.cache_clear()
cache.get_cached_graphs.cache_clear()
with (
patch.object(graph_db, "get_graph", new_callable=AsyncMock) as mock_get,
patch.object(
graph_db, "get_graph_all_versions", new_callable=AsyncMock
) as mock_versions,
patch.object(
graph_db, "list_graphs_paginated", new_callable=AsyncMock
) as mock_list,
):
mock_get.return_value = {"id": mock_graph_id, "version": 1}
mock_versions.return_value = [{"version": 1}]
mock_list.return_value = {
"graphs": [],
"total_count": 0,
"page": 1,
"page_size": 250,
}
# Populate caches
await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
await cache.get_cached_graphs(mock_user_id, 1, 250)
initial_calls = {
"get": mock_get.call_count,
"versions": mock_versions.call_count,
"list": mock_list.call_count,
}
# Verify cache is being used
await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
await cache.get_cached_graphs(mock_user_id, 1, 250)
assert mock_get.call_count == initial_calls["get"]
assert mock_versions.call_count == initial_calls["versions"]
assert mock_list.call_count == initial_calls["list"]
# Simulate update_graph cache invalidation
cache.get_cached_graph.cache_delete(mock_graph_id, None, mock_user_id)
cache.get_cached_graph_all_versions.cache_delete(
mock_graph_id, mock_user_id
)
cache.get_cached_graphs.cache_delete(mock_user_id, 1, 250)
# Next calls should hit database
await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
await cache.get_cached_graphs(mock_user_id, 1, 250)
assert mock_get.call_count == initial_calls["get"] + 1
assert mock_versions.call_count == initial_calls["versions"] + 1
assert mock_list.call_count == initial_calls["list"] + 1
class TestUserPreferencesCacheInvalidation:
"""Test cache invalidation for user preferences operations."""
@pytest.mark.asyncio
async def test_update_preferences_clears_cache(self, mock_user_id):
"""Test that updating preferences clears the preferences cache."""
# Clear cache
cache.get_cached_user_preferences.cache_clear()
with patch.object(
cache.user_db, "get_user_notification_preference", new_callable=AsyncMock
) as mock_get_prefs:
mock_prefs = {"email_notifications": True, "push_notifications": False}
mock_get_prefs.return_value = mock_prefs
# First call hits database
result1 = await cache.get_cached_user_preferences(mock_user_id)
assert mock_get_prefs.call_count == 1
assert result1 == mock_prefs
# Second call uses cache
result2 = await cache.get_cached_user_preferences(mock_user_id)
assert mock_get_prefs.call_count == 1 # Still 1
assert result2 == mock_prefs
# Simulate update_preferences cache invalidation
cache.get_cached_user_preferences.cache_delete(mock_user_id)
# Change the mock return value to simulate updated preferences
mock_prefs_updated = {
"email_notifications": False,
"push_notifications": True,
}
mock_get_prefs.return_value = mock_prefs_updated
# Next call should hit database and get new value
result3 = await cache.get_cached_user_preferences(mock_user_id)
assert mock_get_prefs.call_count == 2
assert result3 == mock_prefs_updated
@pytest.mark.asyncio
async def test_timezone_cache_operations(self, mock_user_id):
"""Test timezone cache and its operations."""
# Clear cache
cache.get_cached_user_timezone.cache_clear()
with patch.object(
cache.user_db, "get_user_by_id", new_callable=AsyncMock
) as mock_get_user:
# Use a simple object that supports attribute access
class MockUser:
def __init__(self, timezone):
self.timezone = timezone
mock_user = MockUser("America/New_York")
mock_get_user.return_value = mock_user
# First call hits database
result1 = await cache.get_cached_user_timezone(mock_user_id)
assert mock_get_user.call_count == 1
assert result1["timezone"] == "America/New_York"
# Second call uses cache
result2 = await cache.get_cached_user_timezone(mock_user_id)
assert mock_get_user.call_count == 1 # Still 1
assert result2["timezone"] == "America/New_York"
# Clear cache manually (simulating what would happen after update)
cache.get_cached_user_timezone.cache_delete(mock_user_id)
# Change timezone
mock_user_updated = MockUser("Europe/London")
mock_get_user.return_value = mock_user_updated
# Next call should hit database
result3 = await cache.get_cached_user_timezone(mock_user_id)
assert mock_get_user.call_count == 2
assert result3["timezone"] == "Europe/London"

class TestExecutionCacheInvalidation:
"""Test cache invalidation for execution operations."""
@pytest.mark.asyncio
async def test_execution_cache_cleared_on_graph_delete(
self, mock_user_id, mock_graph_id
):
"""Test that execution caches are cleared when a graph is deleted."""
# Clear cache
cache.get_cached_graph_executions.cache_clear()
with patch.object(
cache.execution_db, "get_graph_executions_paginated", new_callable=AsyncMock
) as mock_exec:
mock_exec.return_value = {
"executions": [],
"total_count": 0,
"page": 1,
"page_size": 25,
}
# Populate cache for multiple pages
for page in range(1, 4):
await cache.get_cached_graph_executions(
mock_graph_id, mock_user_id, page, 25
)
initial_calls = mock_exec.call_count
# Verify cache is used
for page in range(1, 4):
await cache.get_cached_graph_executions(
mock_graph_id, mock_user_id, page, 25
)
assert mock_exec.call_count == initial_calls # No new calls
# Simulate graph deletion clearing execution caches
for page in range(1, 10): # Clear more pages as done in delete_graph
cache.get_cached_graph_executions.cache_delete(
mock_graph_id, mock_user_id, page, 25
)
# Next calls should hit database
for page in range(1, 4):
await cache.get_cached_graph_executions(
mock_graph_id, mock_user_id, page, 25
)
assert mock_exec.call_count == initial_calls + 3 # 3 new calls

class TestCacheInfo:
"""Test cache information and metrics."""
def test_cache_info_returns_correct_metrics(self):
"""Test that cache_info returns correct metrics."""
# Clear all caches
cache.get_cached_graphs.cache_clear()
cache.get_cached_graph.cache_clear()
# Get initial info
info_graphs = cache.get_cached_graphs.cache_info()
info_graph = cache.get_cached_graph.cache_info()
assert info_graphs["size"] == 0
assert info_graph["size"] == 0
# Note: We can't directly test cache population without real async context,
# but we can verify the cache_info structure
assert "size" in info_graphs
assert "maxsize" in info_graphs
assert "ttl_seconds" in info_graphs
def test_cache_clear_removes_all_entries(self):
"""Test that cache_clear removes all entries."""
# This test verifies the cache_clear method exists and can be called
cache.get_cached_graphs.cache_clear()
cache.get_cached_graph.cache_clear()
cache.get_cached_graph_all_versions.cache_clear()
cache.get_cached_graph_executions.cache_clear()
cache.get_cached_graphs_executions.cache_clear()
cache.get_cached_user_preferences.cache_clear()
cache.get_cached_user_timezone.cache_clear()
# After clear, all caches should be empty
assert cache.get_cached_graphs.cache_info()["size"] == 0
assert cache.get_cached_graph.cache_info()["size"] == 0
assert cache.get_cached_graph_all_versions.cache_info()["size"] == 0
assert cache.get_cached_graph_executions.cache_info()["size"] == 0
assert cache.get_cached_graphs_executions.cache_info()["size"] == 0
assert cache.get_cached_user_preferences.cache_info()["size"] == 0
assert cache.get_cached_user_timezone.cache_info()["size"] == 0
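The tests above exercise cached functions that expose `cache_delete`, `cache_clear`, and `cache_info` attributes. A minimal sketch of a TTL-cache decorator with that interface follows; this is illustrative only, and the real `backend.util.cache.cached` implementation (e.g. its key normalization for positional vs. keyword arguments) may differ:

```python
import asyncio
import time
from functools import wraps


def cached(maxsize=128, ttl_seconds=300):
    """Hypothetical TTL cache decorator exposing the cache_delete /
    cache_clear / cache_info attributes the tests above rely on."""

    def decorator(func):
        store = {}  # key -> (timestamp, value)

        def make_key(*args, **kwargs):
            # Naive key: the real implementation would normalize
            # positional vs. keyword arguments to a single canonical form.
            return args + tuple(sorted(kwargs.items()))

        @wraps(func)
        async def wrapper(*args, **kwargs):
            key = make_key(*args, **kwargs)
            entry = store.get(key)
            if entry and time.monotonic() - entry[0] < ttl_seconds:
                return entry[1]  # fresh cache hit
            value = await func(*args, **kwargs)
            if len(store) >= maxsize:
                store.pop(next(iter(store)))  # evict oldest insertion
            store[key] = (time.monotonic(), value)
            return value

        def cache_delete(*args, **kwargs):
            store.pop(make_key(*args, **kwargs), None)

        wrapper.cache_delete = cache_delete
        wrapper.cache_clear = store.clear
        wrapper.cache_info = lambda: {
            "size": len(store),
            "maxsize": maxsize,
            "ttl_seconds": ttl_seconds,
        }
        return wrapper

    return decorator
```

With this shape, `cache_delete(*key_args)` invalidates a single entry while `cache_clear()` empties the whole store, which is exactly the distinction the invalidation tests assert on.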


@@ -28,11 +28,8 @@ from pydantic import BaseModel
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict
import backend.server.cache_config as cache_config
import backend.server.integrations.router
import backend.server.routers.analytics
import backend.server.routers.cache as cache
import backend.server.v2.library.cache as library_cache
import backend.server.v2.library.db as library_db
from backend.data import api_key as api_key_db
from backend.data import execution as execution_db
@@ -42,6 +39,7 @@ from backend.data.credit import (
AutoTopUpConfig,
RefundRequest,
TransactionHistory,
UserCredit,
get_auto_top_up,
get_user_credit_model,
set_auto_top_up,
@@ -59,6 +57,7 @@ from backend.data.onboarding import (
from backend.data.user import (
get_or_create_user,
get_user_by_id,
get_user_notification_preference,
update_user_email,
update_user_notification_preference,
update_user_timezone,
@@ -85,6 +84,7 @@ from backend.server.model import (
UpdateTimezoneRequest,
UploadFileResponse,
)
from backend.util.cache import cached
from backend.util.clients import get_scheduler_client
from backend.util.cloud_storage import get_cloud_storage_handler
from backend.util.exceptions import GraphValidationError, NotFoundError
@@ -108,9 +108,6 @@ def _create_file_size_error(size_bytes: int, max_size_mb: int) -> HTTPException:
settings = Settings()
logger = logging.getLogger(__name__)
_user_credit_model = get_user_credit_model()
# Define the API routes
v1_router = APIRouter()
@@ -169,9 +166,7 @@ async def get_user_timezone_route(
) -> TimezoneResponse:
"""Get user timezone setting."""
user = await get_or_create_user(user_data)
# Use cached timezone for subsequent calls
result = await cache.get_cached_user_timezone(user.id)
return TimezoneResponse(timezone=result["timezone"])
return TimezoneResponse(timezone=user.timezone)
@v1_router.post(
@@ -185,7 +180,6 @@ async def update_user_timezone_route(
) -> TimezoneResponse:
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
user = await update_user_timezone(user_id, str(request.timezone))
cache.get_cached_user_timezone.cache_delete(user_id)
return TimezoneResponse(timezone=user.timezone)
@@ -198,7 +192,7 @@ async def update_user_timezone_route(
async def get_preferences(
user_id: Annotated[str, Security(get_user_id)],
) -> NotificationPreference:
preferences = await cache.get_cached_user_preferences(user_id)
preferences = await get_user_notification_preference(user_id)
return preferences
@@ -213,10 +207,6 @@ async def update_preferences(
preferences: NotificationPreferenceDTO = Body(...),
) -> NotificationPreference:
output = await update_user_notification_preference(user_id, preferences)
# Clear preferences cache after update
cache.get_cached_user_preferences.cache_delete(user_id)
return output
@@ -486,7 +476,8 @@ async def upload_file(
async def get_user_credits(
user_id: Annotated[str, Security(get_user_id)],
) -> dict[str, int]:
return {"credits": await _user_credit_model.get_credits(user_id)}
user_credit_model = await get_user_credit_model(user_id)
return {"credits": await user_credit_model.get_credits(user_id)}
@v1_router.post(
@@ -498,9 +489,8 @@ async def get_user_credits(
async def request_top_up(
request: RequestTopUp, user_id: Annotated[str, Security(get_user_id)]
):
checkout_url = await _user_credit_model.top_up_intent(
user_id, request.credit_amount
)
user_credit_model = await get_user_credit_model(user_id)
checkout_url = await user_credit_model.top_up_intent(user_id, request.credit_amount)
return {"checkout_url": checkout_url}
@@ -515,7 +505,8 @@ async def refund_top_up(
transaction_key: str,
metadata: dict[str, str],
) -> int:
return await _user_credit_model.top_up_refund(user_id, transaction_key, metadata)
user_credit_model = await get_user_credit_model(user_id)
return await user_credit_model.top_up_refund(user_id, transaction_key, metadata)
@v1_router.patch(
@@ -525,7 +516,8 @@ async def refund_top_up(
dependencies=[Security(requires_user)],
)
async def fulfill_checkout(user_id: Annotated[str, Security(get_user_id)]):
await _user_credit_model.fulfill_checkout(user_id=user_id)
user_credit_model = await get_user_credit_model(user_id)
await user_credit_model.fulfill_checkout(user_id=user_id)
return Response(status_code=200)
@@ -539,18 +531,23 @@ async def configure_user_auto_top_up(
request: AutoTopUpConfig, user_id: Annotated[str, Security(get_user_id)]
) -> str:
if request.threshold < 0:
raise ValueError("Threshold must be greater than 0")
raise HTTPException(status_code=422, detail="Threshold must be greater than 0")
if request.amount < 500 and request.amount != 0:
raise ValueError("Amount must be greater than or equal to 500")
if request.amount < request.threshold:
raise ValueError("Amount must be greater than or equal to threshold")
raise HTTPException(
status_code=422, detail="Amount must be greater than or equal to 500"
)
if request.amount != 0 and request.amount < request.threshold:
raise HTTPException(
status_code=422, detail="Amount must be greater than or equal to threshold"
)
current_balance = await _user_credit_model.get_credits(user_id)
user_credit_model = await get_user_credit_model(user_id)
current_balance = await user_credit_model.get_credits(user_id)
if current_balance < request.threshold:
await _user_credit_model.top_up_credits(user_id, request.amount)
await user_credit_model.top_up_credits(user_id, request.amount)
else:
await _user_credit_model.top_up_credits(user_id, 0)
await user_credit_model.top_up_credits(user_id, 0)
await set_auto_top_up(
user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
@@ -598,15 +595,13 @@ async def stripe_webhook(request: Request):
event["type"] == "checkout.session.completed"
or event["type"] == "checkout.session.async_payment_succeeded"
):
await _user_credit_model.fulfill_checkout(
session_id=event["data"]["object"]["id"]
)
await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"])
if event["type"] == "charge.dispute.created":
await _user_credit_model.handle_dispute(event["data"]["object"])
await UserCredit().handle_dispute(event["data"]["object"])
if event["type"] == "refund.created" or event["type"] == "charge.dispute.closed":
await _user_credit_model.deduct_credits(event["data"]["object"])
await UserCredit().deduct_credits(event["data"]["object"])
return Response(status_code=200)
@@ -620,7 +615,8 @@ async def stripe_webhook(request: Request):
async def manage_payment_method(
user_id: Annotated[str, Security(get_user_id)],
) -> dict[str, str]:
return {"url": await _user_credit_model.create_billing_portal_session(user_id)}
user_credit_model = await get_user_credit_model(user_id)
return {"url": await user_credit_model.create_billing_portal_session(user_id)}
@v1_router.get(
@@ -638,7 +634,8 @@ async def get_credit_history(
if transaction_count_limit < 1 or transaction_count_limit > 1000:
raise ValueError("Transaction count limit must be between 1 and 1000")
return await _user_credit_model.get_transaction_history(
user_credit_model = await get_user_credit_model(user_id)
return await user_credit_model.get_transaction_history(
user_id=user_id,
transaction_time_ceiling=transaction_time,
transaction_count_limit=transaction_count_limit,
@@ -655,7 +652,8 @@ async def get_credit_history(
async def get_refund_requests(
user_id: Annotated[str, Security(get_user_id)],
) -> list[RefundRequest]:
return await _user_credit_model.get_refund_requests(user_id)
user_credit_model = await get_user_credit_model(user_id)
return await user_credit_model.get_refund_requests(user_id)
########################################################
@@ -676,10 +674,11 @@ class DeleteGraphResponse(TypedDict):
async def list_graphs(
user_id: Annotated[str, Security(get_user_id)],
) -> Sequence[graph_db.GraphMeta]:
paginated_result = await cache.get_cached_graphs(
paginated_result = await graph_db.list_graphs_paginated(
user_id=user_id,
page=1,
page_size=250,
filter_by="active",
)
return paginated_result.graphs
@@ -702,26 +701,13 @@ async def get_graph(
version: int | None = None,
for_export: bool = False,
) -> graph_db.GraphModel:
# Use cache for non-export requests
if not for_export:
graph = await cache.get_cached_graph(
graph_id=graph_id,
version=version,
user_id=user_id,
)
# If graph not found, clear cache entry as permissions may have changed
if not graph:
cache.get_cached_graph.cache_delete(
graph_id=graph_id, version=version, user_id=user_id
)
else:
graph = await graph_db.get_graph(
graph_id,
version,
user_id=user_id,
for_export=for_export,
include_subgraphs=True, # needed to construct full credentials input schema
)
graph = await graph_db.get_graph(
graph_id,
version,
user_id=user_id,
for_export=for_export,
include_subgraphs=True, # needed to construct full credentials input schema
)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return graph
@@ -736,7 +722,7 @@ async def get_graph(
async def get_graph_all_versions(
graph_id: str, user_id: Annotated[str, Security(get_user_id)]
) -> Sequence[graph_db.GraphModel]:
graphs = await cache.get_cached_graph_all_versions(graph_id, user_id=user_id)
graphs = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
if not graphs:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return graphs
@@ -760,26 +746,6 @@ async def create_new_graph(
# as the graph already valid and no sub-graphs are returned back.
await graph_db.create_graph(graph, user_id=user_id)
await library_db.create_library_agent(graph, user_id=user_id)
# Clear graphs list cache after creating new graph
cache.get_cached_graphs.cache_delete(
user_id=user_id,
page=1,
page_size=cache_config.V1_GRAPHS_PAGE_SIZE,
)
for page in range(1, cache_config.MAX_PAGES_TO_CLEAR):
library_cache.get_cached_library_agents.cache_delete(
user_id=user_id,
page=page,
page_size=cache_config.V1_LIBRARY_AGENTS_PAGE_SIZE,
)
# Clear my agents cache so user sees new agent immediately
import backend.server.v2.store.cache
backend.server.v2.store.cache._clear_my_agents_cache(user_id)
return await on_graph_activate(graph, user_id=user_id)
@@ -795,32 +761,7 @@ async def delete_graph(
if active_version := await graph_db.get_graph(graph_id, user_id=user_id):
await on_graph_deactivate(active_version, user_id=user_id)
result = DeleteGraphResponse(
version_counts=await graph_db.delete_graph(graph_id, user_id=user_id)
)
# Clear caches after deleting graph
cache.get_cached_graphs.cache_delete(
user_id=user_id,
page=1,
page_size=cache_config.V1_GRAPHS_PAGE_SIZE,
)
cache.get_cached_graph.cache_delete(
graph_id=graph_id, version=None, user_id=user_id
)
cache.get_cached_graph_all_versions.cache_delete(graph_id, user_id=user_id)
# Clear my agents cache so user sees agent removed immediately
import backend.server.v2.store.cache
backend.server.v2.store.cache._clear_my_agents_cache(user_id)
# Clear library agent by graph_id cache
library_cache.get_cached_library_agent_by_graph_id.cache_delete(
graph_id=graph_id, user_id=user_id
)
return result
return {"version_counts": await graph_db.delete_graph(graph_id, user_id=user_id)}
@v1_router.put(
@@ -876,18 +817,6 @@ async def update_graph(
include_subgraphs=True,
)
assert new_graph_version_with_subgraphs # make type checker happy
# Clear caches after updating graph
cache.get_cached_graph.cache_delete(
graph_id=graph_id, version=None, user_id=user_id
)
cache.get_cached_graph_all_versions.cache_delete(graph_id, user_id=user_id)
cache.get_cached_graphs.cache_delete(
user_id=user_id,
page=1,
page_size=cache_config.V1_GRAPHS_PAGE_SIZE,
)
return new_graph_version_with_subgraphs
@@ -946,36 +875,14 @@ async def execute_graph(
graph_version: Optional[int] = None,
preset_id: Optional[str] = None,
) -> execution_db.GraphExecutionMeta:
current_balance = await _user_credit_model.get_credits(user_id)
user_credit_model = await get_user_credit_model(user_id)
current_balance = await user_credit_model.get_credits(user_id)
if current_balance <= 0:
raise HTTPException(
status_code=402,
detail="Insufficient balance to execute the agent. Please top up your account.",
)
# Invalidate caches before execution starts so frontend sees fresh data
cache.get_cached_graphs_executions.cache_delete(
user_id=user_id,
page=1,
page_size=cache_config.V1_GRAPHS_PAGE_SIZE,
)
for page in range(1, cache_config.MAX_PAGES_TO_CLEAR):
cache.get_cached_graph_execution.cache_delete(
graph_id=graph_id, user_id=user_id, version=graph_version
)
cache.get_cached_graph_executions.cache_delete(
graph_id=graph_id,
user_id=user_id,
page=page,
page_size=cache_config.V1_GRAPH_EXECUTIONS_PAGE_SIZE,
)
library_cache.get_cached_library_agents.cache_delete(
user_id=user_id,
page=page,
page_size=cache_config.V1_LIBRARY_AGENTS_PAGE_SIZE,
)
try:
result = await execution_utils.add_graph_execution(
graph_id=graph_id,
@@ -988,7 +895,6 @@ async def execute_graph(
# Record successful graph execution
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
record_graph_operation(operation="execute", status="success")
return result
except GraphValidationError as e:
# Record failed graph execution
@@ -1064,7 +970,7 @@ async def _stop_graph_run(
async def list_graphs_executions(
user_id: Annotated[str, Security(get_user_id)],
) -> list[execution_db.GraphExecutionMeta]:
paginated_result = await cache.get_cached_graphs_executions(
paginated_result = await execution_db.get_graph_executions_paginated(
user_id=user_id,
page=1,
page_size=250,
@@ -1086,7 +992,7 @@ async def list_graph_executions(
25, ge=1, le=100, description="Number of executions per page"
),
) -> execution_db.GraphExecutionsPaginated:
return await cache.get_cached_graph_executions(
return await execution_db.get_graph_executions_paginated(
graph_id=graph_id,
user_id=user_id,
page=page,


@@ -23,10 +23,13 @@ client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
def setup_app_auth(mock_jwt_user, setup_test_user):
"""Setup auth overrides for all tests in this module"""
from autogpt_libs.auth.jwt_utils import get_jwt_payload
# The setup_test_user fixture has already run, so the user exists in the
# database; it returns the user_id synchronously, so there is nothing to await
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
@@ -102,13 +105,13 @@ def test_get_graph_blocks(
mock_block.id = "test-block"
mock_block.disabled = False
# Mock get_blocks where it's imported at the top of v1.py
# Mock get_blocks
mocker.patch(
"backend.server.routers.v1.get_blocks",
return_value={"test-block": lambda: mock_block},
)
# Mock block costs where it's imported inside the function
# Mock block costs
mocker.patch(
"backend.data.credit.get_block_cost",
return_value=[{"cost": 10, "type": "credit"}],
@@ -194,8 +197,12 @@ def test_get_user_credits(
snapshot: Snapshot,
) -> None:
"""Test get user credits endpoint"""
mock_credit_model = mocker.patch("backend.server.routers.v1._user_credit_model")
mock_credit_model = Mock()
mock_credit_model.get_credits = AsyncMock(return_value=1000)
mocker.patch(
"backend.server.routers.v1.get_user_credit_model",
return_value=mock_credit_model,
)
response = client.get("/credits")
@@ -215,10 +222,14 @@ def test_request_top_up(
snapshot: Snapshot,
) -> None:
"""Test request top up endpoint"""
mock_credit_model = mocker.patch("backend.server.routers.v1._user_credit_model")
mock_credit_model = Mock()
mock_credit_model.top_up_intent = AsyncMock(
return_value="https://checkout.example.com/session123"
)
mocker.patch(
"backend.server.routers.v1.get_user_credit_model",
return_value=mock_credit_model,
)
request_data = {"credit_amount": 500}
@@ -261,6 +272,74 @@ def test_get_auto_top_up(
)
def test_configure_auto_top_up(
mocker: pytest_mock.MockFixture,
snapshot: Snapshot,
) -> None:
"""Test configure auto top-up endpoint - this test would have caught the enum casting bug"""
# Mock the set_auto_top_up function to avoid database operations
mocker.patch(
"backend.server.routers.v1.set_auto_top_up",
return_value=None,
)
# Mock credit model to avoid Stripe API calls
mock_credit_model = mocker.AsyncMock()
mock_credit_model.get_credits.return_value = 50 # Current balance below threshold
mock_credit_model.top_up_credits.return_value = None
mocker.patch(
"backend.server.routers.v1.get_user_credit_model",
return_value=mock_credit_model,
)
# Test data
request_data = {
"threshold": 100,
"amount": 500,
}
response = client.post("/credits/auto-top-up", json=request_data)
# This should succeed with our fix, but would have failed before with the enum casting error
assert response.status_code == 200
assert response.json() == "Auto top-up settings updated"
def test_configure_auto_top_up_validation_errors(
mocker: pytest_mock.MockFixture,
) -> None:
"""Test configure auto top-up endpoint validation"""
# Mock set_auto_top_up to avoid database operations for successful case
mocker.patch("backend.server.routers.v1.set_auto_top_up")
# Mock credit model to avoid Stripe API calls for the successful case
mock_credit_model = mocker.AsyncMock()
mock_credit_model.get_credits.return_value = 50
mock_credit_model.top_up_credits.return_value = None
mocker.patch(
"backend.server.routers.v1.get_user_credit_model",
return_value=mock_credit_model,
)
# Test negative threshold
response = client.post(
"/credits/auto-top-up", json={"threshold": -1, "amount": 500}
)
assert response.status_code == 422 # Validation error
# Test amount too small (but not 0)
response = client.post(
"/credits/auto-top-up", json={"threshold": 100, "amount": 100}
)
assert response.status_code == 422 # Validation error
# Test amount = 0 (should be allowed)
response = client.post("/credits/auto-top-up", json={"threshold": 100, "amount": 0})
assert response.status_code == 200 # Should succeed
# Graphs endpoints tests
def test_get_graphs(
mocker: pytest_mock.MockFixture,


@@ -1,299 +0,0 @@
#!/usr/bin/env python3
"""
Complete audit of all @cached functions to verify proper cache invalidation.
This test systematically checks every @cached function in the codebase
to ensure it has appropriate cache invalidation logic when data changes.
"""
import pytest

class TestCacheInvalidationAudit:
"""Audit all @cached functions for proper invalidation."""
def test_v1_router_caches(self):
"""
V1 Router cached functions:
- _get_cached_blocks(): ✓ NEVER CHANGES (blocks are static in code)
"""
# No invalidation needed for static data
pass
def test_v1_cache_module_graph_caches(self):
"""
V1 Cache module graph-related caches:
- get_cached_graphs(user_id, page, page_size): ✓ HAS INVALIDATION
Cleared in: v1.py create_graph(), delete_graph(), update_graph_metadata(), stop_graph_execution()
- get_cached_graph(graph_id, version, user_id): ✓ HAS INVALIDATION
Cleared in: v1.py delete_graph(), update_graph(), delete_graph_execution()
- get_cached_graph_all_versions(graph_id, user_id): ✓ HAS INVALIDATION
Cleared in: v1.py delete_graph(), update_graph(), delete_graph_execution()
- get_cached_graph_executions(graph_id, user_id, page, page_size): ✓ HAS INVALIDATION
Cleared in: v1.py stop_graph_execution()
Also cleared in: v2/library/routes/presets.py
- get_cached_graphs_executions(user_id, page, page_size): ✓ HAS INVALIDATION
Cleared in: v1.py stop_graph_execution()
- get_cached_graph_execution(graph_exec_id, user_id): ✓ HAS INVALIDATION
Cleared in: v1.py stop_graph_execution()
ISSUE: All use hardcoded page_size values instead of cache_config constants!
"""
# Document that v1 routes should migrate to use cache_config
pass
def test_v1_cache_module_user_caches(self):
"""
V1 Cache module user-related caches:
- get_cached_user_timezone(user_id): ✓ HAS INVALIDATION
Cleared in: v1.py update_user_profile()
- get_cached_user_preferences(user_id): ✓ HAS INVALIDATION
Cleared in: v1.py update_user_notification_preferences()
"""
pass
def test_v2_store_cache_functions(self):
"""
V2 Store cached functions:
- _get_cached_user_profile(user_id): ✓ HAS INVALIDATION
Cleared in: v2/store/routes.py update_or_create_profile()
- _get_cached_store_agents(...): ⚠️ PARTIAL INVALIDATION
Cleared in: v2/admin/store_admin_routes.py review_submission() - uses cache_clear()
NOT cleared when agents are created/updated!
- _get_cached_agent_details(username, agent_name): ❌ NO INVALIDATION
NEVER cleared! Relies only on TTL (15 min)
- _get_cached_agent_graph(store_listing_version_id): ❌ NO INVALIDATION
NEVER cleared! Relies only on TTL (1 hour)
- _get_cached_store_agent_by_version(store_listing_version_id): ❌ NO INVALIDATION
NEVER cleared! Relies only on TTL (1 hour)
- _get_cached_store_creators(...): ❌ NO INVALIDATION
NEVER cleared! Relies only on TTL (1 hour)
- _get_cached_creator_details(username): ❌ NO INVALIDATION
NEVER cleared! Relies only on TTL (1 hour)
- _get_cached_my_agents(user_id, page, page_size): ❌ NO INVALIDATION
NEVER cleared! Users won't see new agents for 5 minutes!
CRITICAL BUG: Should be cleared when user creates/deletes agents
- _get_cached_submissions(user_id, page, page_size): ✓ HAS INVALIDATION
Cleared via: _clear_submissions_cache() helper
Called in: create_submission(), edit_submission(), delete_submission()
Called in: v2/admin/store_admin_routes.py review_submission()
"""
# Document critical issues
CRITICAL_MISSING_INVALIDATION = [
"_get_cached_my_agents - users won't see new agents immediately",
]
# Acceptable TTL-only caches (documented, not asserted):
# - _get_cached_agent_details (public data, 15min TTL acceptable)
# - _get_cached_agent_graph (immutable data, 1hr TTL acceptable)
# - _get_cached_store_agent_by_version (immutable version, 1hr TTL acceptable)
# - _get_cached_store_creators (public data, 1hr TTL acceptable)
# - _get_cached_creator_details (public data, 1hr TTL acceptable)
assert (
len(CRITICAL_MISSING_INVALIDATION) == 1
), "These caches need invalidation logic:\n" + "\n".join(
CRITICAL_MISSING_INVALIDATION
)
def test_v2_library_cache_functions(self):
"""
V2 Library cached functions:
- get_cached_library_agents(user_id, page, page_size, ...): ✓ HAS INVALIDATION
Cleared in: v1.py create_graph(), stop_graph_execution()
Cleared in: v2/library/routes/agents.py add_library_agent(), remove_library_agent()
- get_cached_library_agent_favorites(user_id, page, page_size): ✓ HAS INVALIDATION
Cleared in: v2/library/routes/agents.py favorite/unfavorite endpoints
- get_cached_library_agent(library_agent_id, user_id): ✓ HAS INVALIDATION
Cleared in: v2/library/routes/agents.py remove_library_agent()
- get_cached_library_agent_by_graph_id(graph_id, user_id): ❌ NO INVALIDATION
NEVER cleared! Relies only on TTL (30 min)
Should be cleared when graph is deleted
- get_cached_library_agent_by_store_version(store_listing_version_id, user_id): ❌ NO INVALIDATION
NEVER cleared! Relies only on TTL (1 hour)
Probably acceptable as store versions are immutable
- get_cached_library_presets(user_id, page, page_size): ✓ HAS INVALIDATION
Cleared via: _clear_presets_list_cache() helper
Called in: v2/library/routes/presets.py preset mutations
- get_cached_library_preset(preset_id, user_id): ✓ HAS INVALIDATION
Cleared in: v2/library/routes/presets.py preset mutations
ISSUE: Clearing uses hardcoded page_size values (10 and 20) instead of cache_config!
"""
pass
def test_immutable_singleton_caches(self):
"""
Caches that never need invalidation (singleton or immutable):
- get_webhook_block_ids(): ✓ STATIC (blocks in code)
- get_io_block_ids(): ✓ STATIC (blocks in code)
- get_supabase(): ✓ CLIENT INSTANCE (no invalidation needed)
- get_async_supabase(): ✓ CLIENT INSTANCE (no invalidation needed)
- _get_all_providers(): ✓ STATIC CONFIG (providers in code)
- get_redis(): ✓ CLIENT INSTANCE (no invalidation needed)
- load_webhook_managers(): ✓ STATIC (managers in code)
- load_all_blocks(): ✓ STATIC (blocks in code)
- get_cached_blocks(): ✓ STATIC (blocks in code)
"""
pass
def test_feature_flag_cache(self):
"""
Feature flag cache:
- _fetch_user_context_data(user_id): ⚠️ LONG TTL
TTL: 24 hours
NO INVALIDATION
This is probably acceptable as user context changes infrequently.
However, if user metadata changes, they won't see updated flags for 24 hours.
"""
pass
def test_onboarding_cache(self):
"""
Onboarding cache:
- onboarding_enabled(): ⚠️ NO INVALIDATION
TTL: 5 minutes
NO INVALIDATION
Should probably be cleared when store agents are added/removed.
But 5min TTL is acceptable for this use case.
"""
pass

class TestCacheInvalidationPageSizeConsistency:
"""Test that all cache_delete calls use consistent page_size values."""
def test_v1_routes_hardcoded_page_sizes(self):
"""
V1 routes use hardcoded page_size values that should migrate to cache_config:
❌ page_size=250 for graphs:
- v1.py line 765: cache.get_cached_graphs.cache_delete(user_id, page=1, page_size=250)
- v1.py line 791: cache.get_cached_graphs.cache_delete(user_id, page=1, page_size=250)
- v1.py line 859: cache.get_cached_graphs.cache_delete(user_id, page=1, page_size=250)
- v1.py line 929: cache.get_cached_graphs_executions.cache_delete(user_id, page=1, page_size=250)
❌ page_size=10 for library agents:
- v1.py line 768: library_cache.get_cached_library_agents.cache_delete(..., page_size=10)
- v1.py line 940: library_cache.get_cached_library_agents.cache_delete(..., page_size=10)
❌ page_size=25 for graph executions:
- v1.py line 937: cache.get_cached_graph_executions.cache_delete(..., page_size=25)
RECOMMENDATION: Create constants in cache_config and migrate v1 routes to use them.
"""
from backend.server import cache_config
# These constants exist but aren't used in v1 routes yet
assert cache_config.V1_GRAPHS_PAGE_SIZE == 250
assert cache_config.V1_LIBRARY_AGENTS_PAGE_SIZE == 10
assert cache_config.V1_GRAPH_EXECUTIONS_PAGE_SIZE == 25
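The constants asserted above imply a centralized module shaped roughly like the sketch below; the names and values match the assertions in this file, but the real `backend/server/cache_config.py` may define more (the `MAX_PAGES_TO_CLEAR` value here is an assumption):

```python
# Sketch of a centralized cache_config module, matching the values
# the audit tests assert against.
V1_GRAPHS_PAGE_SIZE = 250            # page size used by v1 list_graphs
V1_LIBRARY_AGENTS_PAGE_SIZE = 10     # page size used by v1 for library agents
V1_GRAPH_EXECUTIONS_PAGE_SIZE = 25   # page size used by v1 execution listing
MAX_PAGES_TO_CLEAR = 10              # assumed bound for multi-page cache clears
```

Centralizing these values means a `cache_delete` call site and the route that populated the cache cannot silently drift to different page sizes, which is the failure mode this audit documents.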
def test_v2_library_routes_hardcoded_page_sizes(self):
"""
V2 library routes use hardcoded page_size values:
❌ v2/library/routes/agents.py:
- line 233: cache_delete(..., page_size=10)
❌ v2/library/routes/presets.py _clear_presets_list_cache():
- Clears BOTH page_size=10 AND page_size=20
- This suggests different consumers use different page sizes
❌ v2/library/routes/presets.py:
- line 449: cache_delete(..., page_size=10)
- line 452: cache_delete(..., page_size=25)
RECOMMENDATION: Migrate to use cache_config constants.
"""
from backend.server import cache_config
# Constants exist for library
assert cache_config.V2_LIBRARY_AGENTS_PAGE_SIZE == 10
assert cache_config.V2_LIBRARY_PRESETS_PAGE_SIZE == 20
assert cache_config.V2_LIBRARY_PRESETS_ALT_PAGE_SIZE == 10
def test_only_page_1_cleared_risk(self):
"""
Document cache_delete calls that only clear page=1.
RISKY PATTERN: Many cache_delete calls only clear page=1:
- v1.py create_graph(): Only clears page=1 of graphs
- v1.py delete_graph(): Only clears page=1 of graphs
- v1.py update_graph_metadata(): Only clears page=1 of graphs
- v1.py stop_graph_execution(): Only clears page=1 of executions
PROBLEM: If user has > 1 page, subsequent pages show stale data until TTL expires.
SOLUTIONS:
1. Use cache_clear() to clear all pages (nuclear option)
2. Loop through multiple pages like _clear_submissions_cache does
3. Accept TTL-based expiry for pages 2+ (current approach)
The current approach is probably acceptable, given that the TTL values are reasonable.
"""
pass
class TestCriticalCacheBugs:
"""Document critical cache bugs that need fixing."""
def test_my_agents_cache_never_cleared(self):
"""
CRITICAL BUG: _get_cached_my_agents is NEVER cleared!
Impact:
- User creates a new agent → Won't see it in "My Agents" for 5 minutes
- User deletes an agent → Still sees it in "My Agents" for 5 minutes
Fix needed:
1. Create _clear_my_agents_cache() helper (like _clear_submissions_cache)
2. Call it from v1.py create_graph() and delete_graph()
3. Use cache_config.V2_MY_AGENTS_PAGE_SIZE constant
Location: v2/store/cache.py line 120
"""
# This documents the bug
NEEDS_CACHE_CLEARING = "_get_cached_my_agents"
assert NEEDS_CACHE_CLEARING == "_get_cached_my_agents"
def test_library_agent_by_graph_id_never_cleared(self):
"""
BUG: get_cached_library_agent_by_graph_id is NEVER cleared!
Impact:
- User deletes a graph → Library still shows it as available for 30 minutes
Fix needed:
- Clear in v1.py delete_graph()
- Clear in v2/library/routes/agents.py remove_library_agent()
Location: v2/library/cache.py line 59
"""
pass
if __name__ == "__main__":
pytest.main([__file__, "-v"])


@@ -1,95 +0,0 @@
#!/usr/bin/env python3
"""
Test suite to verify cache_config constants are being used correctly.
This ensures that the centralized cache_config.py constants are actually
used throughout the codebase, not just defined.
"""
import pytest
from backend.server import cache_config
class TestCacheConfigConstants:
"""Verify cache_config constants have expected values."""
def test_v2_store_page_sizes(self):
"""Test V2 Store API page size constants."""
assert cache_config.V2_STORE_AGENTS_PAGE_SIZE == 20
assert cache_config.V2_STORE_CREATORS_PAGE_SIZE == 20
assert cache_config.V2_STORE_SUBMISSIONS_PAGE_SIZE == 20
assert cache_config.V2_MY_AGENTS_PAGE_SIZE == 20
def test_v2_library_page_sizes(self):
"""Test V2 Library API page size constants."""
assert cache_config.V2_LIBRARY_AGENTS_PAGE_SIZE == 10
assert cache_config.V2_LIBRARY_PRESETS_PAGE_SIZE == 20
assert cache_config.V2_LIBRARY_PRESETS_ALT_PAGE_SIZE == 10
def test_v1_page_sizes(self):
"""Test V1 API page size constants."""
assert cache_config.V1_GRAPHS_PAGE_SIZE == 250
assert cache_config.V1_LIBRARY_AGENTS_PAGE_SIZE == 10
assert cache_config.V1_GRAPH_EXECUTIONS_PAGE_SIZE == 25
def test_cache_clearing_config(self):
"""Test cache clearing configuration."""
assert cache_config.MAX_PAGES_TO_CLEAR == 20
def test_get_page_sizes_for_clearing_helper(self):
"""Test the helper function for getting page sizes to clear."""
# Single page size
result = cache_config.get_page_sizes_for_clearing(20)
assert result == [20]
# Multiple page sizes
result = cache_config.get_page_sizes_for_clearing(20, 10)
assert result == [20, 10]
# With None alt_page_size
result = cache_config.get_page_sizes_for_clearing(20, None)
assert result == [20]
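A minimal reimplementation consistent with the assertions above (the real helper lives in backend.server.cache_config and may differ in its details):

```python
def get_page_sizes_for_clearing(*page_sizes):
    """Return the page sizes to clear, dropping None placeholders
    while preserving the order they were passed in."""
    return [size for size in page_sizes if size is not None]

assert get_page_sizes_for_clearing(20) == [20]
assert get_page_sizes_for_clearing(20, 10) == [20, 10]
assert get_page_sizes_for_clearing(20, None) == [20]
```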
class TestCacheConfigUsage:
"""Test that cache_config constants are actually used in the code."""
def test_store_routes_import_cache_config(self):
"""Verify store routes imports cache_config."""
import backend.server.v2.store.routes as store_routes
# Check that cache_config is imported
assert hasattr(store_routes, "backend")
assert hasattr(store_routes.backend.server, "cache_config")
def test_store_cache_uses_constants(self):
"""Verify store cache module uses cache_config constants."""
import backend.server.v2.store.cache as store_cache
# Check the module imports cache_config
assert hasattr(store_cache, "backend")
assert hasattr(store_cache.backend.server, "cache_config")
# The _clear_submissions_cache function should use the constant
import inspect
source = inspect.getsource(store_cache._clear_submissions_cache)
assert (
"cache_config.V2_STORE_SUBMISSIONS_PAGE_SIZE" in source
), "_clear_submissions_cache must use cache_config.V2_STORE_SUBMISSIONS_PAGE_SIZE"
assert (
"cache_config.MAX_PAGES_TO_CLEAR" in source
), "_clear_submissions_cache must use cache_config.MAX_PAGES_TO_CLEAR"
def test_admin_routes_use_constants(self):
"""Verify admin routes use cache_config constants."""
import backend.server.v2.admin.store_admin_routes as admin_routes
# Check that cache_config is imported
assert hasattr(admin_routes, "backend")
assert hasattr(admin_routes.backend.server, "cache_config")
if __name__ == "__main__":
pytest.main([__file__, "-v"])


@@ -1,263 +0,0 @@
#!/usr/bin/env python3
"""
Comprehensive test suite for cache invalidation consistency across the entire backend.
This test file identifies ALL locations where cache_delete is called with hardcoded
parameters (especially page_size) and ensures they match the corresponding route defaults.
CRITICAL: If any test in this file fails, it means cache invalidation will be broken
and users will see stale data after mutations.
Key problem areas identified:
1. v1.py routes: Uses page_size=250 for graphs, and cache clearing matches with page_size=250 ✓
2. v1.py routes: Uses page_size=10 for library agents clearing
3. v2/library routes: Uses page_size=10 for library agents clearing
4. v2/store routes: Uses page_size=20 for submissions clearing (in _clear_submissions_cache)
5. v2/library presets: Uses page_size=10 AND page_size=20 for presets (dual clearing)
"""
import pytest
class TestCacheInvalidationConsistency:
"""Test that all cache_delete calls use correct parameters matching route defaults."""
def test_v1_graphs_cache_page_size_consistency(self):
"""
Test v1 graphs routes use consistent page_size.
Locations that must match:
- routes/v1.py line 682: default page_size=250
- routes/v1.py line 765: cache_delete with page_size=250
- routes/v1.py line 791: cache_delete with page_size=250
- routes/v1.py line 859: cache_delete with page_size=250
- routes/v1.py line 929: cache_delete with page_size=250
- routes/v1.py line 1034: default page_size=250
"""
V1_GRAPHS_DEFAULT_PAGE_SIZE = 250
# This is the expected value - if this test fails, check all the above locations
assert V1_GRAPHS_DEFAULT_PAGE_SIZE == 250, (
"If you changed the default page_size for v1 graphs, you must update:\n"
"1. routes/v1.py list_graphs() default parameter\n"
"2. routes/v1.py create_graph() cache_delete call\n"
"3. routes/v1.py delete_graph() cache_delete call\n"
"4. routes/v1.py update_graph_metadata() cache_delete call\n"
"5. routes/v1.py stop_graph_execution() cache_delete call\n"
"6. routes/v1.py list_graph_run_events() default parameter"
)
def test_v1_library_agents_cache_page_size_consistency(self):
"""
Test v1 library agents cache clearing uses consistent page_size.
Locations that must match:
- routes/v1.py line 768: cache_delete with page_size=10
- routes/v1.py line 940: cache_delete with page_size=10
- v2/library/routes/agents.py line 233: cache_delete with page_size=10
WARNING: These hardcode page_size=10 but we need to verify this matches
the actual page_size used when fetching library agents!
"""
V1_LIBRARY_AGENTS_CLEARING_PAGE_SIZE = 10
assert V1_LIBRARY_AGENTS_CLEARING_PAGE_SIZE == 10, (
"If you changed the library agents clearing page_size, you must update:\n"
"1. routes/v1.py create_graph() cache clearing loop\n"
"2. routes/v1.py stop_graph_execution() cache clearing loop\n"
"3. v2/library/routes/agents.py add_library_agent() cache clearing loop"
)
# TODO: This should be verified against the actual default used in library routes
def test_v1_graph_executions_cache_page_size_consistency(self):
"""
Test v1 graph executions cache clearing uses consistent page_size.
Locations:
- routes/v1.py line 937: cache_delete with page_size=25
- v2/library/routes/presets.py line 449: cache_delete with page_size=10
- v2/library/routes/presets.py line 452: cache_delete with page_size=25
"""
V1_GRAPH_EXECUTIONS_CLEARING_PAGE_SIZE = 25
# Note: presets.py clears BOTH page_size=10 AND page_size=25
# This suggests there may be multiple consumers with different page sizes
assert V1_GRAPH_EXECUTIONS_CLEARING_PAGE_SIZE == 25
def test_v2_store_submissions_cache_page_size_consistency(self):
"""
Test v2 store submissions use consistent page_size.
Locations that must match:
- v2/store/routes.py line 484: default page_size=20
- v2/store/cache.py line 18: _clear_submissions_cache uses page_size=20
This is already tested in test_cache_delete.py but documented here for completeness.
"""
V2_STORE_SUBMISSIONS_DEFAULT_PAGE_SIZE = 20
V2_STORE_SUBMISSIONS_CLEARING_PAGE_SIZE = 20
assert (
V2_STORE_SUBMISSIONS_DEFAULT_PAGE_SIZE
== V2_STORE_SUBMISSIONS_CLEARING_PAGE_SIZE
), (
"The default page_size for store submissions must match the hardcoded value in _clear_submissions_cache!\n"
"Update both:\n"
"1. v2/store/routes.py get_submissions() default parameter\n"
"2. v2/store/cache.py _clear_submissions_cache() hardcoded page_size"
)
def test_v2_library_presets_cache_page_size_consistency(self):
"""
Test v2 library presets cache clearing uses consistent page_size.
Locations:
- v2/library/routes/presets.py line 36: cache_delete with page_size=10
- v2/library/routes/presets.py line 39: cache_delete with page_size=20
This route clears BOTH page_size=10 and page_size=20, suggesting multiple consumers.
"""
V2_LIBRARY_PRESETS_CLEARING_PAGE_SIZES = [10, 20]
assert 10 in V2_LIBRARY_PRESETS_CLEARING_PAGE_SIZES
assert 20 in V2_LIBRARY_PRESETS_CLEARING_PAGE_SIZES
# TODO: Verify these match the actual page_size defaults used in preset routes
def test_cache_clearing_helper_functions_documented(self):
"""
Document all cache clearing helper functions and their hardcoded parameters.
Helper functions that wrap cache_delete with hardcoded params:
1. v2/store/cache.py::_clear_submissions_cache() - hardcodes page_size=20, num_pages=20
2. v2/library/routes/presets.py::_clear_presets_list_cache() - hardcodes page_size=10 AND 20, num_pages=20
These helpers are DANGEROUS because:
- They hide the hardcoded parameters
- They loop through multiple pages with hardcoded page_size
- If the route default changes, these won't clear the right cache entries
"""
HELPER_FUNCTIONS = {
"_clear_submissions_cache": {
"file": "v2/store/cache.py",
"page_size": 20,
"num_pages": 20,
"risk": "HIGH - single page_size, could miss entries if default changes",
},
"_clear_presets_list_cache": {
"file": "v2/library/routes/presets.py",
"page_size": [10, 20],
"num_pages": 20,
"risk": "MEDIUM - clears multiple page_sizes, but could still miss new ones",
},
}
assert (
len(HELPER_FUNCTIONS) == 2
), "If you add new cache clearing helper functions, document them here!"
def test_cache_delete_without_page_loops_are_risky(self):
"""
Document cache_delete calls that clear only page=1 (risky if there are multiple pages).
Single page cache_delete calls:
- routes/v1.py line 765: Only clears page=1 with page_size=250
- routes/v1.py line 791: Only clears page=1 with page_size=250
- routes/v1.py line 859: Only clears page=1 with page_size=250
These are RISKY because:
- If a user has more than one page of graphs, pages 2+ won't be invalidated
- User could see stale data on pagination
RECOMMENDATION: Use cache_clear() or loop through multiple pages like
_clear_submissions_cache does.
"""
SINGLE_PAGE_CLEARS = [
"routes/v1.py line 765: create_graph clears only page=1",
"routes/v1.py line 791: delete_graph clears only page=1",
"routes/v1.py line 859: update_graph_metadata clears only page=1",
]
# This test documents the issue but doesn't fail
# Consider this a TODO to fix these cache clearing strategies
assert (
len(SINGLE_PAGE_CLEARS) >= 3
), "These cache_delete calls should probably loop through multiple pages"
def test_all_cached_functions_have_proper_invalidation(self):
"""
Verify all @cached functions have corresponding cache_delete calls.
Functions with proper invalidation:
✓ get_cached_user_profile - cleared on profile update
✓ get_cached_store_agents - cleared on admin review (cache_clear)
✓ get_cached_submissions - cleared via _clear_submissions_cache helper
✓ get_cached_graphs - cleared on graph mutations
✓ get_cached_library_agents - cleared on library changes
Functions that might not have proper invalidation:
? get_cached_agent_details - not explicitly cleared
? get_cached_store_creators - not explicitly cleared
? get_cached_my_agents - not explicitly cleared (no helper function exists!)
This is a documentation test - actual verification requires code analysis.
"""
NEEDS_VERIFICATION = [
"get_cached_agent_details",
"get_cached_store_creators",
"get_cached_my_agents", # NO CLEARING FUNCTION EXISTS!
]
assert "get_cached_my_agents" in NEEDS_VERIFICATION, (
"get_cached_my_agents has no cache clearing logic - this is a BUG!\n"
"When a user creates/deletes an agent, their 'my agents' list won't update."
)
class TestCacheKeyParameterOrdering:
"""
Test that cache_delete calls use the same parameter order as the @cached function.
The @cached decorator uses function signature order to create cache keys.
cache_delete must use the exact same order or it won't find the cached entry!
"""
def test_cached_function_parameter_order_matters(self):
"""
Document that parameter order in cache_delete must match @cached function signature.
Example from v2/store/cache.py:
@cached(...)
async def _get_cached_submissions(user_id: str, page: int, page_size: int):
...
CORRECT: _get_cached_submissions.cache_delete(user_id, page=1, page_size=20)
WRONG: _get_cached_submissions.cache_delete(page=1, user_id=user_id, page_size=20)
The cached decorator generates keys based on the POSITIONAL order, so parameter
order must match between the function definition and cache_delete call.
"""
# This is a documentation test - no assertion needed
# Real verification requires inspecting each cache_delete call
pass
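The pitfall can be sketched with a naive key builder (an assumption about how a caching decorator might derive keys, for illustration only; the real @cached implementation may normalize arguments differently):

```python
def naive_key(*args, **kwargs):
    """Naive cache key: positional and keyword args hash differently."""
    return (args, tuple(sorted(kwargs.items())))

# The entry was cached via a positional call...
k_store = naive_key("u1", 1, 20)
# ...but cache_delete passes everything by keyword:
k_delete = naive_key(user_id="u1", page=1, page_size=20)

assert k_store != k_delete  # the delete misses the stored entry
```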
def test_named_parameters_vs_positional_in_cache_delete(self):
"""
Document best practice: use named parameters in cache_delete for safety.
Good practice seen in codebase:
- cache.get_cached_graphs.cache_delete(user_id=user_id, page=1, page_size=250)
- library_cache.get_cached_library_agents.cache_delete(user_id=user_id, page=page, page_size=10)
This is safer than positional arguments because:
1. More readable
2. Less likely to get order wrong
3. Self-documenting what each parameter means
"""
pass
if __name__ == "__main__":
pytest.main([__file__, "-v"])


@@ -457,8 +457,7 @@ async def test_api_key_with_unicode_characters_normalization_attack(mock_request
"""Test that Unicode normalization doesn't bypass validation."""
# Create auth with composed Unicode character
auth = APIKeyAuthenticator(
header_name="X-API-Key",
expected_token="café", # é is composed
header_name="X-API-Key", expected_token="café" # é is composed
)
# Try with decomposed version (c + a + f + e + ´)
@@ -523,8 +522,8 @@ async def test_api_keys_with_newline_variations(mock_request):
"valid\r\ntoken", # Windows newline
"valid\rtoken", # Mac newline
"valid\x85token", # NEL (Next Line)
"valid\x0btoken", # Vertical Tab
"valid\x0ctoken", # Form Feed
"valid\x0Btoken", # Vertical Tab
"valid\x0Ctoken", # Form Feed
]
for api_key in newline_variations:


@@ -23,6 +23,7 @@ logger = logging.getLogger(__name__)
class AutoModManager:
def __init__(self):
self.config = self._load_config()


@@ -11,8 +11,6 @@ from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
_user_credit_model = get_user_credit_model()
router = APIRouter(
prefix="/admin",
@@ -33,7 +31,8 @@ async def add_user_credits(
logger.info(
f"Admin user {admin_user_id} is adding {amount} credits to user {user_id}"
)
new_balance, transaction_key = await _user_credit_model._add_transaction(
user_credit_model = await get_user_credit_model(user_id)
new_balance, transaction_key = await user_credit_model._add_transaction(
user_id,
amount,
transaction_type=CreditTransactionType.GRANT,


@@ -1,5 +1,5 @@
import json
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, Mock
import fastapi
import fastapi.testclient
@@ -7,12 +7,12 @@ import prisma.enums
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from prisma import Json
from pytest_snapshot.plugin import Snapshot
import backend.server.v2.admin.credit_admin_routes as credit_admin_routes
import backend.server.v2.admin.model as admin_model
from backend.data.model import UserTransaction
from backend.util.json import SafeJson
from backend.util.models import Pagination
app = fastapi.FastAPI()
@@ -37,12 +37,14 @@ def test_add_user_credits_success(
) -> None:
"""Test successful credit addition by admin"""
# Mock the credit model
mock_credit_model = mocker.patch(
"backend.server.v2.admin.credit_admin_routes._user_credit_model"
)
mock_credit_model = Mock()
mock_credit_model._add_transaction = AsyncMock(
return_value=(1500, "transaction-123-uuid")
)
mocker.patch(
"backend.server.v2.admin.credit_admin_routes.get_user_credit_model",
return_value=mock_credit_model,
)
request_data = {
"user_id": target_user_id,
@@ -62,11 +64,17 @@ def test_add_user_credits_success(
call_args = mock_credit_model._add_transaction.call_args
assert call_args[0] == (target_user_id, 500)
assert call_args[1]["transaction_type"] == prisma.enums.CreditTransactionType.GRANT
# Check that metadata is a Json object with the expected content
assert isinstance(call_args[1]["metadata"], Json)
assert call_args[1]["metadata"] == Json(
{"admin_id": admin_user_id, "reason": "Test credit grant for debugging"}
)
# Check that metadata is a SafeJson object with the expected content
assert isinstance(call_args[1]["metadata"], SafeJson)
actual_metadata = call_args[1]["metadata"]
expected_data = {
"admin_id": admin_user_id,
"reason": "Test credit grant for debugging",
}
# SafeJson inherits from Json which stores parsed data in the .data attribute
assert actual_metadata.data["admin_id"] == expected_data["admin_id"]
assert actual_metadata.data["reason"] == expected_data["reason"]
# Snapshot test the response
configured_snapshot.assert_match(
@@ -81,12 +89,14 @@ def test_add_user_credits_negative_amount(
) -> None:
"""Test credit deduction by admin (negative amount)"""
# Mock the credit model
mock_credit_model = mocker.patch(
"backend.server.v2.admin.credit_admin_routes._user_credit_model"
)
mock_credit_model = Mock()
mock_credit_model._add_transaction = AsyncMock(
return_value=(200, "transaction-456-uuid")
)
mocker.patch(
"backend.server.v2.admin.credit_admin_routes.get_user_credit_model",
return_value=mock_credit_model,
)
request_data = {
"user_id": "target-user-id",


@@ -7,8 +7,7 @@ import fastapi
import fastapi.responses
import prisma.enums
import backend.server.cache_config
import backend.server.v2.store.cache
import backend.server.v2.store.cache as store_cache
import backend.server.v2.store.db
import backend.server.v2.store.model
import backend.util.json
@@ -31,7 +30,7 @@ async def get_admin_listings_with_versions(
status: typing.Optional[prisma.enums.SubmissionStatus] = None,
search: typing.Optional[str] = None,
page: int = 1,
page_size: int = backend.server.cache_config.V2_STORE_SUBMISSIONS_PAGE_SIZE,
page_size: int = 20,
):
"""
Get store listings with their version history for admins.
@@ -88,6 +87,11 @@ async def review_submission(
StoreSubmission with updated review information
"""
try:
already_approved = (
await backend.server.v2.store.db.check_submission_already_approved(
store_listing_version_id=store_listing_version_id,
)
)
submission = await backend.server.v2.store.db.review_store_submission(
store_listing_version_id=store_listing_version_id,
is_approved=request.is_approved,
@@ -95,8 +99,11 @@ async def review_submission(
internal_comments=request.internal_comments or "",
reviewer_id=user_id,
)
backend.server.v2.store.cache._clear_submissions_cache(submission.user_id)
backend.server.v2.store.cache._get_cached_store_agents.cache_clear()
state_changed = already_approved != request.is_approved
# Clear caches when the request is approved as it updates what is shown on the store
if state_changed:
store_cache.clear_all_caches()
return submission
except Exception as e:
logger.exception("Error reviewing submission: %s", e)


@@ -118,6 +118,17 @@ def get_blocks(
)
def get_block_by_id(block_id: str) -> BlockInfo | None:
"""
Get a specific block by its ID.
"""
for block_type in load_all_blocks().values():
block: Block[BlockSchema, BlockSchema] = block_type()
if block.id == block_id:
return block.get_info()
return None
def search_blocks(
include_blocks: bool = True,
include_integrations: bool = True,


@@ -53,16 +53,6 @@ class ProviderResponse(BaseModel):
pagination: Pagination
# Search
class SearchRequest(BaseModel):
search_query: str | None = None
filter: list[FilterType] | None = None
by_creator: list[str] | None = None
search_id: str | None = None
page: int | None = None
page_size: int | None = None
class SearchBlocksResponse(BaseModel):
blocks: BlockResponse
total_block_count: int


@@ -110,6 +110,25 @@ async def get_blocks(
)
@router.get(
"/blocks/batch",
summary="Get specific blocks",
response_model=list[builder_model.BlockInfo],
)
async def get_specific_blocks(
block_ids: Annotated[list[str], fastapi.Query()],
) -> list[builder_model.BlockInfo]:
"""
Get specific blocks by their IDs.
"""
blocks = []
for block_id in block_ids:
block = builder_db.get_block_by_id(block_id)
if block:
blocks.append(block)
return blocks
@router.get(
"/providers",
summary="Get Builder integration providers",
@@ -128,30 +147,34 @@ async def get_providers(
)
@router.post(
# Not using POST because orval on the frontend doesn't support Infinite Query with the POST method.
@router.get(
"/search",
summary="Builder search",
tags=["store", "private"],
response_model=builder_model.SearchResponse,
)
async def search(
options: builder_model.SearchRequest,
user_id: Annotated[str, fastapi.Security(get_user_id)],
search_query: Annotated[str | None, fastapi.Query()] = None,
filter: Annotated[list[str] | None, fastapi.Query()] = None,
search_id: Annotated[str | None, fastapi.Query()] = None,
by_creator: Annotated[list[str] | None, fastapi.Query()] = None,
page: Annotated[int, fastapi.Query()] = 1,
page_size: Annotated[int, fastapi.Query()] = 50,
) -> builder_model.SearchResponse:
"""
Search for blocks (including integrations), marketplace agents, and user library agents.
"""
# If no filters are provided, then we will return all types
if not options.filter:
options.filter = [
if not filter:
filter = [
"blocks",
"integrations",
"marketplace_agents",
"my_agents",
]
options.search_query = sanitize_query(options.search_query)
options.page = options.page or 1
options.page_size = options.page_size or 50
search_query = sanitize_query(search_query)
# Blocks&Integrations
blocks = builder_model.SearchBlocksResponse(
@@ -162,13 +185,13 @@ async def search(
total_block_count=0,
total_integration_count=0,
)
if "blocks" in options.filter or "integrations" in options.filter:
if "blocks" in filter or "integrations" in filter:
blocks = builder_db.search_blocks(
include_blocks="blocks" in options.filter,
include_integrations="integrations" in options.filter,
query=options.search_query or "",
page=options.page,
page_size=options.page_size,
include_blocks="blocks" in filter,
include_integrations="integrations" in filter,
query=search_query or "",
page=page,
page_size=page_size,
)
# Library Agents
@@ -176,12 +199,12 @@ async def search(
agents=[],
pagination=Pagination.empty(),
)
if "my_agents" in options.filter:
if "my_agents" in filter:
my_agents = await library_db.list_library_agents(
user_id=user_id,
search_term=options.search_query,
page=options.page,
page_size=options.page_size,
search_term=search_query,
page=page,
page_size=page_size,
)
# Marketplace Agents
@@ -189,12 +212,12 @@ async def search(
agents=[],
pagination=Pagination.empty(),
)
if "marketplace_agents" in options.filter:
if "marketplace_agents" in filter:
marketplace_agents = await store_db.get_store_agents(
creators=options.by_creator,
search_query=options.search_query,
page=options.page,
page_size=options.page_size,
creators=by_creator,
search_query=search_query,
page=page,
page_size=page_size,
)
more_pages = False
@@ -214,7 +237,7 @@ async def search(
"marketplace_agents": marketplace_agents.pagination.total_items,
"my_agents": my_agents.pagination.total_items,
},
page=options.page,
page=page,
more_pages=more_pages,
)


@@ -1,111 +0,0 @@
"""
Cache functions for Library API endpoints.
This module contains all caching decorators and helpers for the Library API,
separated from the main routes for better organization and maintainability.
"""
import backend.server.v2.library.db
from backend.util.cache import cached
# ===== Library Agent Caches =====
# Cache library agents list for 10 minutes
@cached(maxsize=1000, ttl_seconds=600, shared_cache=True)
async def get_cached_library_agents(
user_id: str,
page: int = 1,
page_size: int = 20,
):
"""Cached helper to get library agents list."""
return await backend.server.v2.library.db.list_library_agents(
user_id=user_id,
page=page,
page_size=page_size,
)
# Cache user's favorite agents for 5 minutes - favorites change more frequently
@cached(maxsize=500, ttl_seconds=300, shared_cache=True)
async def get_cached_library_agent_favorites(
user_id: str,
page: int = 1,
page_size: int = 20,
):
"""Cached helper to get user's favorite library agents."""
return await backend.server.v2.library.db.list_favorite_library_agents(
user_id=user_id,
page=page,
page_size=page_size,
)
# Cache individual library agent details for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_agent(
library_agent_id: str,
user_id: str,
):
"""Cached helper to get library agent details."""
return await backend.server.v2.library.db.get_library_agent(
id=library_agent_id,
user_id=user_id,
)
# Cache library agent by graph ID for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_agent_by_graph_id(
graph_id: str,
user_id: str,
):
"""Cached helper to get library agent by graph ID."""
return await backend.server.v2.library.db.get_library_agent_by_graph_id(
graph_id=graph_id,
user_id=user_id,
)
# Cache library agent by store version ID for 1 hour - marketplace agents are more stable
@cached(maxsize=500, ttl_seconds=3600, shared_cache=True)
async def get_cached_library_agent_by_store_version(
store_listing_version_id: str,
user_id: str,
):
"""Cached helper to get library agent by store version ID."""
return await backend.server.v2.library.db.get_library_agent_by_store_version_id(
store_listing_version_id=store_listing_version_id,
user_id=user_id,
)
# ===== Library Preset Caches =====
# Cache library presets list for 30 minutes
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_presets(
user_id: str,
page: int = 1,
page_size: int = 20,
):
"""Cached helper to get library presets list."""
return await backend.server.v2.library.db.list_presets(
user_id=user_id,
page=page,
page_size=page_size,
)
# Cache individual preset details for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_preset(
preset_id: str,
user_id: str,
):
"""Cached helper to get library preset details."""
return await backend.server.v2.library.db.get_preset(
preset_id=preset_id,
user_id=user_id,
)


@@ -1,286 +0,0 @@
"""
Tests for cache invalidation in Library API routes.
This module tests that library caches are properly invalidated when data is modified.
"""
import uuid
from unittest.mock import AsyncMock, patch
import pytest
import backend.server.v2.library.cache as library_cache
import backend.server.v2.library.db as library_db
@pytest.fixture
def mock_user_id():
"""Generate a mock user ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def mock_library_agent_id():
"""Generate a mock library agent ID for testing."""
return str(uuid.uuid4())
class TestLibraryAgentCacheInvalidation:
"""Test cache invalidation for library agent operations."""
@pytest.mark.asyncio
async def test_add_agent_clears_list_cache(self, mock_user_id):
"""Test that adding an agent clears the library agents list cache."""
# Clear cache
library_cache.get_cached_library_agents.cache_clear()
with patch.object(
library_db, "list_library_agents", new_callable=AsyncMock
) as mock_list:
mock_response = {"agents": [], "total_count": 0, "page": 1, "page_size": 20}
mock_list.return_value = mock_response
# First call hits database
await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
assert mock_list.call_count == 1
# Second call uses cache
await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
assert mock_list.call_count == 1 # Still 1, cache used
# Simulate adding an agent (cache invalidation)
for page in range(1, 5):
library_cache.get_cached_library_agents.cache_delete(
mock_user_id, page, 15
)
library_cache.get_cached_library_agents.cache_delete(
mock_user_id, page, 20
)
# Next call should hit database
await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
assert mock_list.call_count == 2
@pytest.mark.asyncio
async def test_delete_agent_clears_multiple_caches(
self, mock_user_id, mock_library_agent_id
):
"""Test that deleting an agent clears both specific and list caches."""
# Clear caches
library_cache.get_cached_library_agent.cache_clear()
library_cache.get_cached_library_agents.cache_clear()
with (
patch.object(
library_db, "get_library_agent", new_callable=AsyncMock
) as mock_get,
patch.object(
library_db, "list_library_agents", new_callable=AsyncMock
) as mock_list,
):
mock_agent = {"id": mock_library_agent_id, "name": "Test Agent"}
mock_get.return_value = mock_agent
mock_list.return_value = {
"agents": [mock_agent],
"total_count": 1,
"page": 1,
"page_size": 20,
}
# Populate caches
await library_cache.get_cached_library_agent(
mock_library_agent_id, mock_user_id
)
await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
initial_calls = {
"get": mock_get.call_count,
"list": mock_list.call_count,
}
# Verify cache is used
await library_cache.get_cached_library_agent(
mock_library_agent_id, mock_user_id
)
await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
assert mock_get.call_count == initial_calls["get"]
assert mock_list.call_count == initial_calls["list"]
# Simulate delete_library_agent cache invalidation
library_cache.get_cached_library_agent.cache_delete(
mock_library_agent_id, mock_user_id
)
for page in range(1, 5):
library_cache.get_cached_library_agents.cache_delete(
mock_user_id, page, 15
)
library_cache.get_cached_library_agents.cache_delete(
mock_user_id, page, 20
)
# Next calls should hit database
await library_cache.get_cached_library_agent(
mock_library_agent_id, mock_user_id
)
await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
assert mock_get.call_count == initial_calls["get"] + 1
assert mock_list.call_count == initial_calls["list"] + 1
@pytest.mark.asyncio
async def test_favorites_cache_operations(self, mock_user_id):
"""Test that favorites cache works independently."""
# Clear cache
library_cache.get_cached_library_agent_favorites.cache_clear()
with patch.object(
library_db, "list_favorite_library_agents", new_callable=AsyncMock
) as mock_favs:
mock_response = {"agents": [], "total_count": 0, "page": 1, "page_size": 20}
mock_favs.return_value = mock_response
# First call hits database
await library_cache.get_cached_library_agent_favorites(mock_user_id, 1, 20)
assert mock_favs.call_count == 1
# Second call uses cache
await library_cache.get_cached_library_agent_favorites(mock_user_id, 1, 20)
assert mock_favs.call_count == 1 # Cache used
# Clear cache
library_cache.get_cached_library_agent_favorites.cache_delete(
mock_user_id, 1, 20
)
# Next call hits database
await library_cache.get_cached_library_agent_favorites(mock_user_id, 1, 20)
assert mock_favs.call_count == 2
class TestLibraryPresetCacheInvalidation:
    """Test cache invalidation for library preset operations."""

    @pytest.mark.asyncio
    async def test_preset_cache_operations(self, mock_user_id):
        """Test preset cache and invalidation."""
        # Clear cache
        library_cache.get_cached_library_presets.cache_clear()
        library_cache.get_cached_library_preset.cache_clear()

        preset_id = str(uuid.uuid4())

        with (
            patch.object(
                library_db, "list_presets", new_callable=AsyncMock
            ) as mock_list,
            patch.object(library_db, "get_preset", new_callable=AsyncMock) as mock_get,
        ):
            mock_preset = {"id": preset_id, "name": "Test Preset"}
            mock_list.return_value = {
                "presets": [mock_preset],
                "total_count": 1,
                "page": 1,
                "page_size": 20,
            }
            mock_get.return_value = mock_preset

            # Populate caches
            await library_cache.get_cached_library_presets(mock_user_id, 1, 20)
            await library_cache.get_cached_library_preset(preset_id, mock_user_id)
            initial_calls = {
                "list": mock_list.call_count,
                "get": mock_get.call_count,
            }

            # Verify cache is used
            await library_cache.get_cached_library_presets(mock_user_id, 1, 20)
            await library_cache.get_cached_library_preset(preset_id, mock_user_id)
            assert mock_list.call_count == initial_calls["list"]
            assert mock_get.call_count == initial_calls["get"]

            # Clear specific preset cache
            library_cache.get_cached_library_preset.cache_delete(
                preset_id, mock_user_id
            )
            # Clear list cache
            library_cache.get_cached_library_presets.cache_delete(mock_user_id, 1, 20)

            # Next calls should hit database
            await library_cache.get_cached_library_presets(mock_user_id, 1, 20)
            await library_cache.get_cached_library_preset(preset_id, mock_user_id)
            assert mock_list.call_count == initial_calls["list"] + 1
            assert mock_get.call_count == initial_calls["get"] + 1
class TestLibraryCacheMetrics:
    """Test library cache metrics and management."""

    def test_cache_info_structure(self):
        """Test that cache_info returns expected structure."""
        info = library_cache.get_cached_library_agents.cache_info()
        assert "size" in info
        assert "maxsize" in info
        assert "ttl_seconds" in info
        assert (
            info["maxsize"] is None
        )  # Redis manages its own size with shared_cache=True
        assert info["ttl_seconds"] == 600  # 10 minutes

    def test_all_library_caches_can_be_cleared(self):
        """Test that all library caches can be cleared."""
        # Clear all library caches
        library_cache.get_cached_library_agents.cache_clear()
        library_cache.get_cached_library_agent_favorites.cache_clear()
        library_cache.get_cached_library_agent.cache_clear()
        library_cache.get_cached_library_agent_by_graph_id.cache_clear()
        library_cache.get_cached_library_agent_by_store_version.cache_clear()
        library_cache.get_cached_library_presets.cache_clear()
        library_cache.get_cached_library_preset.cache_clear()

        # Verify all are empty
        assert library_cache.get_cached_library_agents.cache_info()["size"] == 0
        assert (
            library_cache.get_cached_library_agent_favorites.cache_info()["size"] == 0
        )
        assert library_cache.get_cached_library_agent.cache_info()["size"] == 0
        assert (
            library_cache.get_cached_library_agent_by_graph_id.cache_info()["size"] == 0
        )
        assert (
            library_cache.get_cached_library_agent_by_store_version.cache_info()["size"]
            == 0
        )
        assert library_cache.get_cached_library_presets.cache_info()["size"] == 0
        assert library_cache.get_cached_library_preset.cache_info()["size"] == 0

    def test_cache_ttl_values(self):
        """Test that cache TTL values are set correctly."""
        # Library agents - 10 minutes
        assert (
            library_cache.get_cached_library_agents.cache_info()["ttl_seconds"] == 600
        )
        # Favorites - 5 minutes (more dynamic)
        assert (
            library_cache.get_cached_library_agent_favorites.cache_info()["ttl_seconds"]
            == 300
        )
        # Individual agent - 30 minutes
        assert (
            library_cache.get_cached_library_agent.cache_info()["ttl_seconds"] == 1800
        )
        # Presets - 30 minutes
        assert (
            library_cache.get_cached_library_presets.cache_info()["ttl_seconds"] == 1800
        )
        assert (
            library_cache.get_cached_library_preset.cache_info()["ttl_seconds"] == 1800
        )


@@ -20,7 +20,7 @@ from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
from backend.data.model import CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
-from backend.util.exceptions import NotFoundError
+from backend.util.exceptions import DatabaseError, NotFoundError
from backend.util.json import SafeJson
from backend.util.models import Pagination
from backend.util.settings import Config
@@ -61,11 +61,11 @@ async def list_library_agents(
if page < 1 or page_size < 1:
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
-raise store_exceptions.DatabaseError("Invalid pagination input")
+raise DatabaseError("Invalid pagination input")
if search_term and len(search_term.strip()) > 100:
logger.warning(f"Search term too long: {repr(search_term)}")
-raise store_exceptions.DatabaseError("Search term is too long")
+raise DatabaseError("Search term is too long")
where_clause: prisma.types.LibraryAgentWhereInput = {
"userId": user_id,
@@ -143,7 +143,7 @@ async def list_library_agents(
except prisma.errors.PrismaError as e:
logger.error(f"Database error fetching library agents: {e}")
-raise store_exceptions.DatabaseError("Failed to fetch library agents") from e
+raise DatabaseError("Failed to fetch library agents") from e
async def list_favorite_library_agents(
@@ -172,7 +172,7 @@ async def list_favorite_library_agents(
if page < 1 or page_size < 1:
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
-raise store_exceptions.DatabaseError("Invalid pagination input")
+raise DatabaseError("Invalid pagination input")
where_clause: prisma.types.LibraryAgentWhereInput = {
"userId": user_id,
@@ -229,9 +229,7 @@ async def list_favorite_library_agents(
except prisma.errors.PrismaError as e:
logger.error(f"Database error fetching favorite library agents: {e}")
-raise store_exceptions.DatabaseError(
-"Failed to fetch favorite library agents"
-) from e
+raise DatabaseError("Failed to fetch favorite library agents") from e
async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent:
@@ -273,7 +271,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
except prisma.errors.PrismaError as e:
logger.error(f"Database error fetching library agent: {e}")
-raise store_exceptions.DatabaseError("Failed to fetch library agent") from e
+raise DatabaseError("Failed to fetch library agent") from e
async def get_library_agent_by_store_version_id(
@@ -338,7 +336,7 @@ async def get_library_agent_by_graph_id(
return library_model.LibraryAgent.from_db(agent, sub_graphs=sub_graphs)
except prisma.errors.PrismaError as e:
logger.error(f"Database error fetching library agent by graph ID: {e}")
-raise store_exceptions.DatabaseError("Failed to fetch library agent") from e
+raise DatabaseError("Failed to fetch library agent") from e
async def add_generated_agent_image(
@@ -479,9 +477,7 @@ async def update_agent_version_in_library(
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error updating agent version in library: {e}")
-raise store_exceptions.DatabaseError(
-"Failed to update agent version in library"
-) from e
+raise DatabaseError("Failed to update agent version in library") from e
async def update_library_agent(
@@ -544,7 +540,7 @@ async def update_library_agent(
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error updating library agent: {str(e)}")
-raise store_exceptions.DatabaseError("Failed to update library agent") from e
+raise DatabaseError("Failed to update library agent") from e
async def delete_library_agent(
@@ -572,7 +568,7 @@ async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error deleting library agent: {e}")
-raise store_exceptions.DatabaseError("Failed to delete library agent") from e
+raise DatabaseError("Failed to delete library agent") from e
async def add_store_agent_to_library(
@@ -663,7 +659,7 @@ async def add_store_agent_to_library(
raise
except prisma.errors.PrismaError as e:
logger.error(f"Database error adding agent to library: {e}")
-raise store_exceptions.DatabaseError("Failed to add agent to library") from e
+raise DatabaseError("Failed to add agent to library") from e
##############################################
@@ -697,7 +693,7 @@ async def list_presets(
logger.warning(
"Invalid pagination input: page=%d, page_size=%d", page, page_size
)
-raise store_exceptions.DatabaseError("Invalid pagination parameters")
+raise DatabaseError("Invalid pagination parameters")
query_filter: prisma.types.AgentPresetWhereInput = {
"userId": user_id,
@@ -733,7 +729,7 @@ async def list_presets(
except prisma.errors.PrismaError as e:
logger.error(f"Database error getting presets: {e}")
-raise store_exceptions.DatabaseError("Failed to fetch presets") from e
+raise DatabaseError("Failed to fetch presets") from e
async def get_preset(
@@ -763,7 +759,7 @@ async def get_preset(
return library_model.LibraryAgentPreset.from_db(preset)
except prisma.errors.PrismaError as e:
logger.error(f"Database error getting preset: {e}")
-raise store_exceptions.DatabaseError("Failed to fetch preset") from e
+raise DatabaseError("Failed to fetch preset") from e
async def create_preset(
@@ -813,7 +809,7 @@ async def create_preset(
return library_model.LibraryAgentPreset.from_db(new_preset)
except prisma.errors.PrismaError as e:
logger.error(f"Database error creating preset: {e}")
-raise store_exceptions.DatabaseError("Failed to create preset") from e
+raise DatabaseError("Failed to create preset") from e
async def create_preset_from_graph_execution(
@@ -951,7 +947,7 @@ async def update_preset(
return library_model.LibraryAgentPreset.from_db(updated)
except prisma.errors.PrismaError as e:
logger.error(f"Database error updating preset: {e}")
-raise store_exceptions.DatabaseError("Failed to update preset") from e
+raise DatabaseError("Failed to update preset") from e
async def set_preset_webhook(
@@ -997,7 +993,7 @@ async def delete_preset(user_id: str, preset_id: str) -> None:
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error deleting preset: {e}")
-raise store_exceptions.DatabaseError("Failed to delete preset") from e
+raise DatabaseError("Failed to delete preset") from e
async def fork_library_agent(
@@ -1025,7 +1021,7 @@ async def fork_library_agent(
# TODO: once we have open/closed sourced agents this needs to be enabled ~kcze
# + update library/agents/[id]/page.tsx agent actions
# if not original_agent.can_access_graph:
-# raise store_exceptions.DatabaseError(
+# raise DatabaseError(
# f"User {user_id} cannot access library agent graph {library_agent_id}"
# )
@@ -1039,4 +1035,4 @@ async def fork_library_agent(
return (await create_library_agent(new_graph, user_id))[0]
except prisma.errors.PrismaError as e:
logger.error(f"Database error cloning library agent: {e}")
-raise store_exceptions.DatabaseError("Failed to fork library agent") from e
+raise DatabaseError("Failed to fork library agent") from e


@@ -5,12 +5,10 @@ import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Body, HTTPException, Query, Security, status
from fastapi.responses import Response
-import backend.server.cache_config
-import backend.server.v2.library.cache as library_cache
import backend.server.v2.library.db as library_db
import backend.server.v2.library.model as library_model
import backend.server.v2.store.exceptions as store_exceptions
-from backend.util.exceptions import NotFoundError
+from backend.util.exceptions import DatabaseError, NotFoundError
logger = logging.getLogger(__name__)
@@ -66,22 +64,13 @@ async def list_library_agents(
HTTPException: If a server/database error occurs.
"""
try:
-# Use cache for default queries (no search term, default sort)
-if search_term is None and sort_by == library_model.LibraryAgentSort.UPDATED_AT:
-return await library_cache.get_cached_library_agents(
-user_id=user_id,
-page=page,
-page_size=page_size,
-)
-else:
-# Direct DB query for searches and custom sorts
-return await library_db.list_library_agents(
-user_id=user_id,
-search_term=search_term,
-sort_by=sort_by,
-page=page,
-page_size=page_size,
-)
+return await library_db.list_library_agents(
+user_id=user_id,
+search_term=search_term,
+sort_by=sort_by,
+page=page,
+page_size=page_size,
+)
except Exception as e:
logger.error(f"Could not list library agents for user #{user_id}: {e}")
raise HTTPException(
@@ -125,7 +114,7 @@ async def list_favorite_library_agents(
HTTPException: If a server/database error occurs.
"""
try:
-return await library_cache.get_cached_library_agent_favorites(
+return await library_db.list_favorite_library_agents(
user_id=user_id,
page=page,
page_size=page_size,
@@ -143,9 +132,7 @@ async def get_library_agent(
library_agent_id: str,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryAgent:
-return await library_cache.get_cached_library_agent(
-library_agent_id=library_agent_id, user_id=user_id
-)
+return await library_db.get_library_agent(id=library_agent_id, user_id=user_id)
@router.get("/by-graph/{graph_id}")
@@ -223,28 +210,18 @@ async def add_marketplace_agent_to_library(
HTTPException(500): If a server/database error occurs.
"""
try:
-result = await library_db.add_store_agent_to_library(
+return await library_db.add_store_agent_to_library(
store_listing_version_id=store_listing_version_id,
user_id=user_id,
)
-# Clear library caches after adding new agent
-for page in range(1, backend.server.cache_config.MAX_PAGES_TO_CLEAR):
-library_cache.get_cached_library_agents.cache_delete(
-user_id=user_id,
-page=page,
-page_size=backend.server.cache_config.V2_LIBRARY_AGENTS_PAGE_SIZE,
-)
-return result
except store_exceptions.AgentNotFoundError as e:
logger.warning(
f"Could not find store listing version {store_listing_version_id} "
"to add to library"
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
-except store_exceptions.DatabaseError as e:
+except DatabaseError as e:
logger.error(f"Database error while adding agent to library: {e}", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -286,28 +263,19 @@ async def update_library_agent(
HTTPException(500): If a server/database error occurs.
"""
try:
-result = await library_db.update_library_agent(
+return await library_db.update_library_agent(
library_agent_id=library_agent_id,
user_id=user_id,
auto_update_version=payload.auto_update_version,
is_favorite=payload.is_favorite,
is_archived=payload.is_archived,
)
-for page in range(1, backend.server.cache_config.MAX_PAGES_TO_CLEAR):
-library_cache.get_cached_library_agent_favorites.cache_delete(
-user_id=user_id,
-page=page,
-page_size=backend.server.cache_config.V2_LIBRARY_AGENTS_PAGE_SIZE,
-)
-return result
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
) from e
-except store_exceptions.DatabaseError as e:
+except DatabaseError as e:
logger.error(f"Database error while updating library agent: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -352,18 +320,6 @@ async def delete_library_agent(
await library_db.delete_library_agent(
library_agent_id=library_agent_id, user_id=user_id
)
-# Clear caches after deleting agent
-library_cache.get_cached_library_agent.cache_delete(
-library_agent_id=library_agent_id, user_id=user_id
-)
-for page in range(1, backend.server.cache_config.MAX_PAGES_TO_CLEAR):
-library_cache.get_cached_library_agents.cache_delete(
-user_id=user_id,
-page=page,
-page_size=backend.server.cache_config.V2_LIBRARY_AGENTS_PAGE_SIZE,
-)
return Response(status_code=status.HTTP_204_NO_CONTENT)
except NotFoundError as e:
raise HTTPException(


@@ -4,9 +4,6 @@ from typing import Any, Optional
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Body, HTTPException, Query, Security, status
-import backend.server.cache_config
-import backend.server.routers.cache as cache
-import backend.server.v2.library.cache as library_cache
import backend.server.v2.library.db as db
import backend.server.v2.library.model as models
from backend.data.execution import GraphExecutionMeta
@@ -28,24 +25,6 @@ router = APIRouter(
)
-def _clear_presets_list_cache(
-user_id: str, num_pages: int = backend.server.cache_config.MAX_PAGES_TO_CLEAR
-):
-"""
-Clear the presets list cache for the given user.
-Clears both primary and alternative page sizes for backward compatibility.
-"""
-page_sizes = backend.server.cache_config.get_page_sizes_for_clearing(
-backend.server.cache_config.V2_LIBRARY_PRESETS_PAGE_SIZE,
-backend.server.cache_config.V2_LIBRARY_PRESETS_ALT_PAGE_SIZE,
-)
-for page in range(1, num_pages + 1):
-for page_size in page_sizes:
-library_cache.get_cached_library_presets.cache_delete(
-user_id=user_id, page=page, page_size=page_size
-)
@router.get(
"/presets",
summary="List presets",
@@ -72,21 +51,12 @@ async def list_presets(
models.LibraryAgentPresetResponse: A response containing the list of presets.
"""
try:
-# Use cache only for default queries (no filter)
-if graph_id is None:
-return await library_cache.get_cached_library_presets(
-user_id=user_id,
-page=page,
-page_size=page_size,
-)
-else:
-# Direct DB query for filtered requests
-return await db.list_presets(
-user_id=user_id,
-graph_id=graph_id,
-page=page,
-page_size=page_size,
-)
+return await db.list_presets(
+user_id=user_id,
+graph_id=graph_id,
+page=page,
+page_size=page_size,
+)
except Exception as e:
logger.exception("Failed to list presets for user %s: %s", user_id, e)
raise HTTPException(
@@ -117,7 +87,7 @@ async def get_preset(
HTTPException: If the preset is not found or an error occurs.
"""
try:
-preset = await library_cache.get_cached_library_preset(preset_id, user_id)
+preset = await db.get_preset(user_id, preset_id)
except Exception as e:
logger.exception(
"Error retrieving preset %s for user %s: %s", preset_id, user_id, e
@@ -161,13 +131,9 @@ async def create_preset(
"""
try:
if isinstance(preset, models.LibraryAgentPresetCreatable):
-result = await db.create_preset(user_id, preset)
+return await db.create_preset(user_id, preset)
else:
-result = await db.create_preset_from_graph_execution(user_id, preset)
-_clear_presets_list_cache(user_id)
-return result
+return await db.create_preset_from_graph_execution(user_id, preset)
except NotFoundError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except Exception as e:
@@ -234,9 +200,6 @@ async def setup_trigger(
is_active=True,
),
)
-_clear_presets_list_cache(user_id)
return new_preset
@@ -315,13 +278,6 @@ async def update_preset(
description=preset.description,
is_active=preset.is_active,
)
-# Clear caches after updating preset
-library_cache.get_cached_library_preset.cache_delete(
-preset_id=preset_id, user_id=user_id
-)
-_clear_presets_list_cache(user_id)
except Exception as e:
logger.exception("Preset update failed for user %s: %s", user_id, e)
raise HTTPException(
@@ -395,12 +351,6 @@ async def delete_preset(
try:
await db.delete_preset(user_id, preset_id)
-# Clear caches after deleting preset
-library_cache.get_cached_library_preset.cache_delete(
-preset_id=preset_id, user_id=user_id
-)
-_clear_presets_list_cache(user_id)
except Exception as e:
logger.exception(
"Error deleting preset %s for user %s: %s", preset_id, user_id, e
@@ -451,33 +401,6 @@ async def execute_preset(
merged_node_input = preset.inputs | inputs
merged_credential_inputs = preset.credentials | credential_inputs
-# Clear graph executions cache - use both page sizes for compatibility
-for page in range(1, 10):
-# Clear with alternative page size
-cache.get_cached_graph_executions.cache_delete(
-graph_id=preset.graph_id,
-user_id=user_id,
-page=page,
-page_size=backend.server.cache_config.V2_GRAPH_EXECUTIONS_ALT_PAGE_SIZE,
-)
-cache.get_cached_graph_executions.cache_delete(
-user_id=user_id,
-page=page,
-page_size=backend.server.cache_config.V2_GRAPH_EXECUTIONS_ALT_PAGE_SIZE,
-)
-# Clear with v1 page size (25)
-cache.get_cached_graph_executions.cache_delete(
-graph_id=preset.graph_id,
-user_id=user_id,
-page=page,
-page_size=backend.server.cache_config.V1_GRAPH_EXECUTIONS_PAGE_SIZE,
-)
-cache.get_cached_graph_executions.cache_delete(
-user_id=user_id,
-page=page,
-page_size=backend.server.cache_config.V1_GRAPH_EXECUTIONS_PAGE_SIZE,
-)
return await add_graph_execution(
user_id=user_id,
graph_id=preset.graph_id,


@@ -179,15 +179,14 @@ async def test_get_favorite_library_agents_success(
def test_get_favorite_library_agents_error(
mocker: pytest_mock.MockFixture, test_user_id: str
):
-# Mock the cache function instead of the DB directly since routes now use cache
-mock_cache_call = mocker.patch(
-"backend.server.v2.library.routes.agents.library_cache.get_cached_library_agent_favorites"
+mock_db_call = mocker.patch(
+"backend.server.v2.library.db.list_favorite_library_agents"
)
-mock_cache_call.side_effect = Exception("Test error")
+mock_db_call.side_effect = Exception("Test error")
response = client.get("/agents/favorites")
assert response.status_code == 500
-mock_cache_call.assert_called_once_with(
+mock_db_call.assert_called_once_with(
user_id=test_user_id,
page=1,
page_size=15,


@@ -1,61 +1,22 @@
"""
Cache functions for Store API endpoints.
This module contains all caching decorators and helpers for the Store API,
separated from the main routes for better organization and maintainability.
"""
import backend.server.cache_config
import backend.server.v2.store.db
from backend.util.cache import cached
-def _clear_submissions_cache(
-user_id: str, num_pages: int = backend.server.cache_config.MAX_PAGES_TO_CLEAR
-):
-"""
-Clear the submissions cache for the given user.
-Args:
-user_id: User ID whose cache should be cleared
-num_pages: Number of pages to clear (default from cache_config)
-"""
-for page in range(1, num_pages + 1):
-_get_cached_submissions.cache_delete(
-user_id=user_id,
-page=page,
-page_size=backend.server.cache_config.V2_STORE_SUBMISSIONS_PAGE_SIZE,
-)
##############################################
############### Caches #######################
##############################################
-def _clear_my_agents_cache(
-user_id: str, num_pages: int = backend.server.cache_config.MAX_PAGES_TO_CLEAR
-):
-"""
-Clear the my agents cache for the given user.
-Args:
-user_id: User ID whose cache should be cleared
-num_pages: Number of pages to clear (default from cache_config)
-"""
-for page in range(1, num_pages + 1):
-_get_cached_my_agents.cache_delete(
-user_id=user_id,
-page=page,
-page_size=backend.server.cache_config.V2_MY_AGENTS_PAGE_SIZE,
-)
def clear_all_caches():
"""Clear all caches."""
_get_cached_store_agents.cache_clear()
_get_cached_agent_details.cache_clear()
_get_cached_store_creators.cache_clear()
_get_cached_creator_details.cache_clear()
# Cache user profiles for 1 hour per user
@cached(maxsize=1000, ttl_seconds=3600, shared_cache=True)
async def _get_cached_user_profile(user_id: str):
"""Cached helper to get user profile."""
return await backend.server.v2.store.db.get_user_profile(user_id)
-# Cache store agents list for 15 minutes
+# Cache store agents list for 5 minutes
# Different cache entries for different query combinations
-@cached(maxsize=5000, ttl_seconds=900, shared_cache=True)
+@cached(maxsize=5000, ttl_seconds=300, shared_cache=True)
async def _get_cached_store_agents(
featured: bool,
creator: str | None,
@@ -78,7 +39,7 @@ async def _get_cached_store_agents(
# Cache individual agent details for 15 minutes
-@cached(maxsize=200, ttl_seconds=900, shared_cache=True)
+@cached(maxsize=200, ttl_seconds=300, shared_cache=True)
async def _get_cached_agent_details(username: str, agent_name: str):
"""Cached helper to get agent details."""
return await backend.server.v2.store.db.get_store_agent_details(
@@ -86,26 +47,8 @@ async def _get_cached_agent_details(username: str, agent_name: str):
)
-# Cache agent graphs for 1 hour
-@cached(maxsize=200, ttl_seconds=3600, shared_cache=True)
-async def _get_cached_agent_graph(store_listing_version_id: str):
-"""Cached helper to get agent graph."""
-return await backend.server.v2.store.db.get_available_graph(
-store_listing_version_id
-)
-# Cache agent by version for 1 hour
-@cached(maxsize=200, ttl_seconds=3600, shared_cache=True)
-async def _get_cached_store_agent_by_version(store_listing_version_id: str):
-"""Cached helper to get store agent by version ID."""
-return await backend.server.v2.store.db.get_store_agent_by_version_id(
-store_listing_version_id
-)
-# Cache creators list for 1 hour
-@cached(maxsize=200, ttl_seconds=3600, shared_cache=True)
+# Cache creators list for 5 minutes
+@cached(maxsize=200, ttl_seconds=300, shared_cache=True)
async def _get_cached_store_creators(
featured: bool,
search_query: str | None,
@@ -123,30 +66,10 @@ async def _get_cached_store_creators(
)
-# Cache individual creator details for 1 hour
-@cached(maxsize=100, ttl_seconds=3600, shared_cache=True)
+# Cache individual creator details for 5 minutes
+@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
async def _get_cached_creator_details(username: str):
"""Cached helper to get creator details."""
return await backend.server.v2.store.db.get_store_creator_details(
username=username.lower()
)
-# Cache user's own agents for 5 mins (shorter TTL as this changes more frequently)
-@cached(maxsize=500, ttl_seconds=300, shared_cache=True)
-async def _get_cached_my_agents(user_id: str, page: int, page_size: int):
-"""Cached helper to get user's agents."""
-return await backend.server.v2.store.db.get_my_agents(
-user_id, page=page, page_size=page_size
-)
-# Cache user's submissions for 1 hour (shorter TTL as this changes frequently)
-@cached(maxsize=500, ttl_seconds=3600, shared_cache=True)
-async def _get_cached_submissions(user_id: str, page: int, page_size: int):
-"""Cached helper to get user's submissions."""
-return await backend.server.v2.store.db.get_store_submissions(
-user_id=user_id,
-page=page,
-page_size=page_size,
-)


@@ -25,6 +25,7 @@ from backend.data.notifications import (
NotificationEventModel,
)
from backend.notifications.notifications import queue_notification_async
+from backend.util.exceptions import DatabaseError
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -70,8 +71,7 @@ async def get_store_agents(
logger.debug(
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
)
-sanitized_query = sanitize_query(search_query)
+search_term = sanitize_query(search_query)
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
if featured:
where_clause["featured"] = featured
@@ -80,10 +80,10 @@ async def get_store_agents(
if category:
where_clause["categories"] = {"has": category}
-if sanitized_query:
+if search_term:
where_clause["OR"] = [
-{"agent_name": {"contains": sanitized_query, "mode": "insensitive"}},
-{"description": {"contains": sanitized_query, "mode": "insensitive"}},
+{"agent_name": {"contains": search_term, "mode": "insensitive"}},
+{"description": {"contains": search_term, "mode": "insensitive"}},
]
order_by = []
@@ -142,9 +142,25 @@ async def get_store_agents(
)
except Exception as e:
logger.error(f"Error getting store agents: {e}")
-raise backend.server.v2.store.exceptions.DatabaseError(
-"Failed to fetch store agents"
-) from e
+raise DatabaseError("Failed to fetch store agents") from e
+# TODO: commenting this out as we concerned about potential db load issues
+# finally:
+#     if search_term:
+#         await log_search_term(search_query=search_term)
+async def log_search_term(search_query: str):
+"""Log a search term to the database"""
+# Anonymize the data by preventing correlation with other logs
+date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
+try:
+await prisma.models.SearchTerms.prisma().create(
+data={"searchTerm": search_query, "createdDate": date}
+)
+except Exception as e:
+# Fail silently here so that logging search terms doesn't break the app
+logger.error(f"Error logging search term: {e}")
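The anonymization in `log_search_term` works by truncating the timestamp to midnight UTC, so every term logged on the same day carries an identical `createdDate` and cannot be correlated with an individual request via timing. A small sketch of just that truncation step (the `day_bucket` helper name is ours, not from the codebase):

```python
from datetime import datetime, timezone


def day_bucket(now: datetime) -> datetime:
    # Zero out the sub-day fields so all rows from one day share a timestamp
    return now.replace(hour=0, minute=0, second=0, microsecond=0)


stamp = day_bucket(datetime(2025, 10, 18, 14, 33, 7, 123456, tzinfo=timezone.utc))
```

Day-level granularity trades some analytic precision for privacy: search volume per day is preserved while request-level ordering is discarded.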
async def get_store_agent_details(
@@ -237,9 +253,7 @@ async def get_store_agent_details(
raise
except Exception as e:
logger.error(f"Error getting store agent details: {e}")
-raise backend.server.v2.store.exceptions.DatabaseError(
-"Failed to fetch agent details"
-) from e
+raise DatabaseError("Failed to fetch agent details") from e
async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
@@ -266,9 +280,7 @@ async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
except Exception as e:
logger.error(f"Error getting agent: {e}")
-raise backend.server.v2.store.exceptions.DatabaseError(
-"Failed to fetch agent"
-) from e
+raise DatabaseError("Failed to fetch agent") from e
async def get_store_agent_by_version_id(
@@ -308,9 +320,7 @@ async def get_store_agent_by_version_id(
raise
except Exception as e:
logger.error(f"Error getting store agent details: {e}")
-raise backend.server.v2.store.exceptions.DatabaseError(
-"Failed to fetch agent details"
-) from e
+raise DatabaseError("Failed to fetch agent details") from e
async def get_store_creators(
@@ -336,9 +346,7 @@ async def get_store_creators(
# Sanitize and validate search query by escaping special characters
sanitized_query = search_query.strip()
if not sanitized_query or len(sanitized_query) > 100: # Reasonable length limit
-raise backend.server.v2.store.exceptions.DatabaseError(
-"Invalid search query"
-)
+raise DatabaseError("Invalid search query")
# Escape special SQL characters
sanitized_query = (
@@ -364,11 +372,9 @@ async def get_store_creators(
try:
# Validate pagination parameters
if not isinstance(page, int) or page < 1:
-raise backend.server.v2.store.exceptions.DatabaseError(
-"Invalid page number"
-)
+raise DatabaseError("Invalid page number")
if not isinstance(page_size, int) or page_size < 1 or page_size > 100:
-raise backend.server.v2.store.exceptions.DatabaseError("Invalid page size")
+raise DatabaseError("Invalid page size")
# Get total count for pagination using sanitized where clause
total = await prisma.models.Creator.prisma().count(
@@ -423,9 +429,7 @@ async def get_store_creators(
)
except Exception as e:
logger.error(f"Error getting store creators: {e}")
-raise backend.server.v2.store.exceptions.DatabaseError(
-"Failed to fetch store creators"
-) from e
+raise DatabaseError("Failed to fetch store creators") from e
async def get_store_creator_details(
@@ -460,9 +464,7 @@ async def get_store_creator_details(
raise
except Exception as e:
logger.error(f"Error getting store creator details: {e}")
-raise backend.server.v2.store.exceptions.DatabaseError(
-"Failed to fetch creator details"
-) from e
+raise DatabaseError("Failed to fetch creator details") from e
async def get_store_submissions(
@@ -493,7 +495,6 @@ async def get_store_submissions(
submission_models = []
for sub in submissions:
submission_model = backend.server.v2.store.model.StoreSubmission(
-user_id=sub.user_id,
agent_id=sub.agent_id,
agent_version=sub.agent_version,
name=sub.name,
@@ -711,7 +712,6 @@ async def create_store_submission(
logger.debug(f"Created store listing for agent {agent_id}")
# Return submission details
return backend.server.v2.store.model.StoreSubmission(
-user_id=user_id,
agent_id=agent_id,
agent_version=agent_version,
name=name,
@@ -727,7 +727,21 @@ async def create_store_submission(
store_listing_version_id=store_listing_version_id,
changes_summary=changes_summary,
)
+except prisma.errors.UniqueViolationError as exc:
+# Attempt to check if the error was due to the slug field being unique
+error_str = str(exc)
+if "slug" in error_str.lower():
+logger.debug(
+f"Slug '{slug}' is already in use by another agent (agent_id: {agent_id}) for user {user_id}"
+)
+raise backend.server.v2.store.exceptions.SlugAlreadyInUseError(
+f"The URL slug '{slug}' is already in use by another one of your agents. Please choose a different slug."
+) from exc
+else:
+# Reraise as a generic database error for other unique violations
+raise DatabaseError(
+f"Unique constraint violated (not slug): {error_str}"
+) from exc
except (
backend.server.v2.store.exceptions.AgentNotFoundError,
backend.server.v2.store.exceptions.ListingExistsError,
@@ -735,9 +749,7 @@ async def create_store_submission(
raise
except prisma.errors.PrismaError as e:
logger.error(f"Database error creating store submission: {e}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to create store submission"
) from e
raise DatabaseError("Failed to create store submission") from e
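The slug branch in the `UniqueViolationError` handler above can be sketched in isolation. This is an illustrative helper, not code from the repository: Prisma includes the offending field in the violation message, so a substring check is enough to separate a user-fixable slug collision (surfaced as a 409) from other unique violations (re-raised as a generic `DatabaseError`).

```python
# Illustrative sketch only: classify a Prisma unique-violation message the
# way create_store_submission does, by looking for the slug field name.

def classify_unique_violation(message: str) -> str:
    """Return a coarse category for a UniqueViolationError message."""
    if "slug" in message.lower():
        return "slug_conflict"       # user-facing: pick a different slug (409)
    return "other_unique_violation"  # re-raised as a generic DatabaseError
```

The string check is deliberately loose; a stricter variant could parse Prisma's structured error metadata instead of the message text.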
async def edit_store_submission(
@@ -858,11 +870,8 @@ async def edit_store_submission(
)
if not updated_version:
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to update store listing version"
)
raise DatabaseError("Failed to update store listing version")
return backend.server.v2.store.model.StoreSubmission(
user_id=user_id,
agent_id=current_version.agentGraphId,
agent_version=current_version.agentGraphVersion,
name=name,
@@ -897,9 +906,7 @@ async def edit_store_submission(
raise
except prisma.errors.PrismaError as e:
logger.error(f"Database error editing store submission: {e}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to edit store submission"
) from e
raise DatabaseError("Failed to edit store submission") from e
async def create_store_version(
@@ -996,7 +1003,6 @@ async def create_store_version(
)
# Return submission details
return backend.server.v2.store.model.StoreSubmission(
user_id=user_id,
agent_id=agent_id,
agent_version=agent_version,
name=name,
@@ -1015,9 +1021,7 @@ async def create_store_version(
)
except prisma.errors.PrismaError as e:
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to create new store version"
) from e
raise DatabaseError("Failed to create new store version") from e
async def create_store_review(
@@ -1057,9 +1061,7 @@ async def create_store_review(
except prisma.errors.PrismaError as e:
logger.error(f"Database error creating store review: {e}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to create store review"
) from e
raise DatabaseError("Failed to create store review") from e
async def get_user_profile(
@@ -1083,9 +1085,7 @@ async def get_user_profile(
)
except Exception as e:
logger.error(f"Error getting user profile: {e}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to get user profile"
) from e
raise DatabaseError("Failed to get user profile") from e
async def update_profile(
@@ -1122,7 +1122,7 @@ async def update_profile(
logger.error(
f"Unauthorized update attempt for profile {existing_profile.id} by user {user_id}"
)
raise backend.server.v2.store.exceptions.DatabaseError(
raise DatabaseError(
f"Unauthorized update attempt for profile {existing_profile.id} by user {user_id}"
)
@@ -1147,9 +1147,7 @@ async def update_profile(
)
if updated_profile is None:
logger.error(f"Failed to update profile for user {user_id}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to update profile"
)
raise DatabaseError("Failed to update profile")
return backend.server.v2.store.model.CreatorDetails(
name=updated_profile.name,
@@ -1164,9 +1162,7 @@ async def update_profile(
except prisma.errors.PrismaError as e:
logger.error(f"Database error updating profile: {e}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to update profile"
) from e
raise DatabaseError("Failed to update profile") from e
async def get_my_agents(
@@ -1234,9 +1230,7 @@ async def get_my_agents(
)
except Exception as e:
logger.error(f"Error getting my agents: {e}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch my agents"
) from e
raise DatabaseError("Failed to fetch my agents") from e
async def get_agent(store_listing_version_id: str) -> GraphModel:
@@ -1497,8 +1491,8 @@ async def review_store_submission(
include={"StoreListing": True},
)
if not submission or not submission.StoreListing:
raise backend.server.v2.store.exceptions.DatabaseError(
if not submission:
raise DatabaseError(
f"Failed to update store listing version {store_listing_version_id}"
)
@@ -1587,7 +1581,6 @@ async def review_store_submission(
# Convert to Pydantic model for consistency
return backend.server.v2.store.model.StoreSubmission(
user_id=submission.StoreListing.owningUserId,
agent_id=submission.agentGraphId,
agent_version=submission.agentGraphVersion,
name=submission.name,
@@ -1614,9 +1607,7 @@ async def review_store_submission(
except Exception as e:
logger.error(f"Could not create store submission review: {e}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to create store submission review"
) from e
raise DatabaseError("Failed to create store submission review") from e
async def get_admin_listings_with_versions(
@@ -1720,17 +1711,14 @@ async def get_admin_listings_with_versions(
# Get total count for pagination
total = await prisma.models.StoreListing.prisma().count(where=where)
total_pages = (total + page_size - 1) // page_size
# Convert to response models
listings_with_versions = []
for listing in listings:
versions: list[backend.server.v2.store.model.StoreSubmission] = []
if not listing.OwningUser:
logger.error(f"Listing {listing.id} has no owning user")
continue
# If we have versions, turn them into StoreSubmission models
for version in listing.Versions or []:
version_model = backend.server.v2.store.model.StoreSubmission(
user_id=listing.OwningUser.id,
agent_id=version.agentGraphId,
agent_version=version.agentGraphVersion,
name=version.name,
@@ -1798,6 +1786,27 @@ async def get_admin_listings_with_versions(
)
async def check_submission_already_approved(
store_listing_version_id: str,
) -> bool:
"""Check the submission status of a store listing version."""
try:
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}
)
)
if not store_listing_version:
return False
return (
store_listing_version.submissionStatus
== prisma.enums.SubmissionStatus.APPROVED
)
except Exception as e:
logger.error(f"Error checking submission status: {e}")
return False
async def get_agent_as_admin(
user_id: str | None,
store_listing_version_id: str,

View File

@@ -42,6 +42,7 @@ async def test_get_store_agents(mocker):
versions=["1.0"],
updated_at=datetime.now(),
is_available=False,
useForOnboarding=False,
)
]
@@ -84,6 +85,7 @@ async def test_get_store_agent_details(mocker):
versions=["1.0"],
updated_at=datetime.now(),
is_available=False,
useForOnboarding=False,
)
# Mock active version agent (what we want to return for active version)
@@ -105,6 +107,7 @@ async def test_get_store_agent_details(mocker):
versions=["1.0", "2.0"],
updated_at=datetime.now(),
is_available=True,
useForOnboarding=False,
)
# Create a mock StoreListing result
@@ -248,6 +251,7 @@ async def test_create_store_submission(mocker):
isAvailable=True,
)
],
useForOnboarding=False,
)
# Mock prisma calls
@@ -275,7 +279,6 @@ async def test_create_store_submission(mocker):
# Verify mocks called correctly
mock_agent_graph.return_value.find_first.assert_called_once()
mock_store_listing.return_value.find_first.assert_called_once()
mock_store_listing.return_value.create.assert_called_once()

View File

@@ -1,4 +1,7 @@
class MediaUploadError(Exception):
from backend.util.exceptions import NotFoundError
class MediaUploadError(ValueError):
"""Base exception for media upload errors"""
pass
@@ -48,19 +51,19 @@ class VirusScanError(MediaUploadError):
pass
class StoreError(Exception):
class StoreError(ValueError):
"""Base exception for store-related errors"""
pass
class AgentNotFoundError(StoreError):
class AgentNotFoundError(NotFoundError):
"""Raised when an agent is not found"""
pass
class CreatorNotFoundError(StoreError):
class CreatorNotFoundError(NotFoundError):
"""Raised when a creator is not found"""
pass
@@ -72,25 +75,19 @@ class ListingExistsError(StoreError):
pass
class DatabaseError(StoreError):
"""Raised when there is an error interacting with the database"""
pass
class ProfileNotFoundError(StoreError):
class ProfileNotFoundError(NotFoundError):
"""Raised when a profile is not found"""
pass
class ListingNotFoundError(StoreError):
class ListingNotFoundError(NotFoundError):
"""Raised when a store listing is not found"""
pass
class SubmissionNotFoundError(StoreError):
class SubmissionNotFoundError(NotFoundError):
"""Raised when a submission is not found"""
pass
@@ -106,3 +103,9 @@ class UnauthorizedError(StoreError):
"""Raised when a user is not authorized to perform an action"""
pass
class SlugAlreadyInUseError(StoreError):
"""Raised when a slug is already in use by another agent owned by the user"""
pass
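The refactor above rebases the store exceptions on shared base classes (`ValueError`, `NotFoundError`). One payoff is that a single generic handler can map whole families of errors to HTTP statuses by base class alone. A standalone sketch, where `NotFoundError` stands in for `backend.util.exceptions.NotFoundError` and the `status_for` helper is hypothetical:

```python
# Sketch of the new hierarchy; NotFoundError is a stand-in for the shared
# base in backend.util.exceptions, and status_for is a hypothetical mapper.

class NotFoundError(Exception):
    """Shared base for all not-found errors."""

class StoreError(ValueError):
    """Base exception for store-related errors."""

class AgentNotFoundError(NotFoundError):
    """Raised when an agent is not found."""

class SlugAlreadyInUseError(StoreError):
    """Raised when a slug is already in use by another agent."""

def status_for(exc: Exception) -> int:
    """Map an exception to an HTTP status via its base class."""
    if isinstance(exc, NotFoundError):
        return 404
    if isinstance(exc, ValueError):  # covers StoreError and its subclasses
        return 400
    return 500
```

Because every not-found variant now shares one base, new not-found exceptions get 404 behavior for free instead of needing another `except` arm.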

View File

@@ -98,7 +98,6 @@ class Profile(pydantic.BaseModel):
class StoreSubmission(pydantic.BaseModel):
user_id: str = pydantic.Field(default="", exclude=True)
agent_id: str
agent_version: int
name: str

View File

@@ -135,7 +135,6 @@ def test_creator_details():
def test_store_submission():
submission = backend.server.v2.store.model.StoreSubmission(
user_id="user123",
agent_id="agent123",
agent_version=1,
sub_heading="Test subheading",
@@ -157,7 +156,6 @@ def test_store_submissions_response():
response = backend.server.v2.store.model.StoreSubmissionsResponse(
submissions=[
backend.server.v2.store.model.StoreSubmission(
user_id="user123",
agent_id="agent123",
agent_version=1,
sub_heading="Test subheading",

View File

@@ -8,25 +8,13 @@ import fastapi
import fastapi.responses
import backend.data.graph
import backend.server.cache_config
import backend.server.v2.store.cache as store_cache
import backend.server.v2.store.db
import backend.server.v2.store.exceptions
import backend.server.v2.store.image_gen
import backend.server.v2.store.media
import backend.server.v2.store.model
import backend.util.json
from backend.server.v2.store.cache import (
_clear_submissions_cache,
_get_cached_agent_details,
_get_cached_agent_graph,
_get_cached_creator_details,
_get_cached_my_agents,
_get_cached_store_agent_by_version,
_get_cached_store_agents,
_get_cached_store_creators,
_get_cached_submissions,
_get_cached_user_profile,
)
logger = logging.getLogger(__name__)
@@ -53,7 +41,7 @@ async def get_profile(
Cached for 1 hour per user.
"""
try:
profile = await _get_cached_user_profile(user_id)
profile = await backend.server.v2.store.db.get_user_profile(user_id)
if profile is None:
return fastapi.responses.JSONResponse(
status_code=404,
@@ -99,8 +87,6 @@ async def update_or_create_profile(
updated_profile = await backend.server.v2.store.db.update_profile(
user_id=user_id, profile=profile
)
# Clear the cache for this user after profile update
_get_cached_user_profile.cache_delete(user_id)
return updated_profile
except Exception as e:
logger.exception("Failed to update profile for user %s: %s", user_id, e)
@@ -131,11 +117,10 @@ async def get_agents(
search_query: str | None = None,
category: str | None = None,
page: int = 1,
page_size: int = backend.server.cache_config.V2_STORE_AGENTS_PAGE_SIZE,
page_size: int = 20,
):
"""
Get a paginated list of agents from the store with optional filtering and sorting.
Results are cached for 15 minutes.
Args:
featured (bool, optional): Filter to only show featured agents. Defaults to False.
@@ -171,7 +156,7 @@ async def get_agents(
)
try:
agents = await _get_cached_store_agents(
agents = await store_cache._get_cached_store_agents(
featured=featured,
creator=creator,
sorted_by=sorted_by,
@@ -201,7 +186,6 @@ async def get_agents(
async def get_agent(username: str, agent_name: str):
"""
This is only used on the AgentDetails Page.
Results are cached for 15 minutes.
It returns the store listing agents details.
"""
@@ -209,7 +193,7 @@ async def get_agent(username: str, agent_name: str):
username = urllib.parse.unquote(username).lower()
# URL decode the agent name since it comes from the URL path
agent_name = urllib.parse.unquote(agent_name).lower()
agent = await _get_cached_agent_details(
agent = await store_cache._get_cached_agent_details(
username=username, agent_name=agent_name
)
return agent
@@ -232,10 +216,11 @@ async def get_agent(username: str, agent_name: str):
async def get_graph_meta_by_store_listing_version_id(store_listing_version_id: str):
"""
Get Agent Graph from Store Listing Version ID.
Results are cached for 1 hour.
"""
try:
graph = await _get_cached_agent_graph(store_listing_version_id)
graph = await backend.server.v2.store.db.get_available_graph(
store_listing_version_id
)
return graph
except Exception:
logger.exception("Exception occurred whilst getting agent graph")
@@ -255,10 +240,12 @@ async def get_graph_meta_by_store_listing_version_id(store_listing_version_id: s
async def get_store_agent(store_listing_version_id: str):
"""
Get Store Agent Details from Store Listing Version ID.
Results are cached for 1 hour.
"""
try:
agent = await _get_cached_store_agent_by_version(store_listing_version_id)
agent = await backend.server.v2.store.db.get_store_agent_by_version_id(
store_listing_version_id
)
return agent
except Exception:
logger.exception("Exception occurred whilst getting store agent")
@@ -329,15 +316,13 @@ async def get_creators(
search_query: str | None = None,
sorted_by: str | None = None,
page: int = 1,
page_size: int = backend.server.cache_config.V2_STORE_CREATORS_PAGE_SIZE,
page_size: int = 20,
):
"""
This is needed for:
- Home Page Featured Creators
- Search Results Page
Results are cached for 1 hour.
---
To support this functionality we need:
@@ -356,7 +341,7 @@ async def get_creators(
)
try:
creators = await _get_cached_store_creators(
creators = await store_cache._get_cached_store_creators(
featured=featured,
search_query=search_query,
sorted_by=sorted_by,
@@ -383,12 +368,11 @@ async def get_creator(
):
"""
Get the details of a creator.
Results are cached for 1 hour.
- Creator Details Page
"""
try:
username = urllib.parse.unquote(username).lower()
creator = await _get_cached_creator_details(username=username)
creator = await store_cache._get_cached_creator_details(username=username)
return creator
except Exception:
logger.exception("Exception occurred whilst getting creator details")
@@ -415,16 +399,15 @@ async def get_creator(
async def get_my_agents(
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
page: typing.Annotated[int, fastapi.Query(ge=1)] = 1,
page_size: typing.Annotated[
int, fastapi.Query(ge=1)
] = backend.server.cache_config.V2_MY_AGENTS_PAGE_SIZE,
page_size: typing.Annotated[int, fastapi.Query(ge=1)] = 20,
):
"""
Get user's own agents.
Results are cached for 5 minutes per user.
"""
try:
agents = await _get_cached_my_agents(user_id, page=page, page_size=page_size)
agents = await backend.server.v2.store.db.get_my_agents(
user_id, page=page, page_size=page_size
)
return agents
except Exception:
logger.exception("Exception occurred whilst getting my agents")
@@ -461,10 +444,6 @@ async def delete_submission(
submission_id=submission_id,
)
# Clear submissions cache for this specific user after deletion
if result:
_clear_submissions_cache(user_id)
return result
except Exception:
logger.exception("Exception occurred whilst deleting store submission")
@@ -484,11 +463,10 @@ async def delete_submission(
async def get_submissions(
user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
page: int = 1,
page_size: int = backend.server.cache_config.V2_STORE_SUBMISSIONS_PAGE_SIZE,
page_size: int = 20,
):
"""
Get a paginated list of store submissions for the authenticated user.
Results are cached for 1 hour per user.
Args:
user_id (str): ID of the authenticated user
@@ -511,8 +489,10 @@ async def get_submissions(
status_code=422, detail="Page size must be greater than 0"
)
try:
listings = await _get_cached_submissions(
user_id, page=page, page_size=page_size
listings = await backend.server.v2.store.db.get_store_submissions(
user_id=user_id,
page=page,
page_size=page_size,
)
return listings
except Exception:
@@ -566,9 +546,13 @@ async def create_submission(
recommended_schedule_cron=submission_request.recommended_schedule_cron,
)
_clear_submissions_cache(user_id)
return result
except backend.server.v2.store.exceptions.SlugAlreadyInUseError as e:
logger.warning("Slug already in use: %s", str(e))
return fastapi.responses.JSONResponse(
status_code=409,
content={"detail": str(e)},
)
except Exception:
logger.exception("Exception occurred whilst creating store submission")
return fastapi.responses.JSONResponse(
@@ -617,8 +601,6 @@ async def edit_submission(
recommended_schedule_cron=submission_request.recommended_schedule_cron,
)
_clear_submissions_cache(user_id)
return result
@@ -811,15 +793,10 @@ async def get_cache_metrics():
)
# Add metrics for each cache
add_cache_metrics("user_profile", _get_cached_user_profile)
add_cache_metrics("store_agents", _get_cached_store_agents)
add_cache_metrics("agent_details", _get_cached_agent_details)
add_cache_metrics("agent_graph", _get_cached_agent_graph)
add_cache_metrics("agent_by_version", _get_cached_store_agent_by_version)
add_cache_metrics("store_creators", _get_cached_store_creators)
add_cache_metrics("creator_details", _get_cached_creator_details)
add_cache_metrics("my_agents", _get_cached_my_agents)
add_cache_metrics("submissions", _get_cached_submissions)
add_cache_metrics("store_agents", store_cache._get_cached_store_agents)
add_cache_metrics("agent_details", store_cache._get_cached_agent_details)
add_cache_metrics("store_creators", store_cache._get_cached_store_creators)
add_cache_metrics("creator_details", store_cache._get_cached_creator_details)
# Add metadata/help text at the beginning
prometheus_output = [

View File

@@ -534,7 +534,6 @@ def test_get_submissions_success(
mocked_value = backend.server.v2.store.model.StoreSubmissionsResponse(
submissions=[
backend.server.v2.store.model.StoreSubmission(
user_id="user123",
name="Test Agent",
description="Test agent description",
image_urls=["test.jpg"],

View File

@@ -4,18 +4,12 @@ Test suite for verifying cache_delete functionality in store routes.
Tests that specific cache entries can be deleted while preserving others.
"""
import datetime
from unittest.mock import AsyncMock, patch
import pytest
from backend.server.v2.store import routes
from backend.server.v2.store.model import (
ProfileDetails,
StoreAgent,
StoreAgentDetails,
StoreAgentsResponse,
)
from backend.server.v2.store import cache as store_cache
from backend.server.v2.store.model import StoreAgent, StoreAgentsResponse
from backend.util.models import Pagination
@@ -54,10 +48,10 @@ class TestCacheDeletion:
return_value=mock_response,
) as mock_db:
# Clear cache first
routes._get_cached_store_agents.cache_clear()
store_cache._get_cached_store_agents.cache_clear()
# First call - should hit database
result1 = await routes._get_cached_store_agents(
result1 = await store_cache._get_cached_store_agents(
featured=False,
creator=None,
sorted_by=None,
@@ -70,7 +64,7 @@ class TestCacheDeletion:
assert result1.agents[0].agent_name == "Test Agent"
# Second call with same params - should use cache
await routes._get_cached_store_agents(
await store_cache._get_cached_store_agents(
featured=False,
creator=None,
sorted_by=None,
@@ -82,7 +76,7 @@ class TestCacheDeletion:
assert mock_db.call_count == 1 # No additional DB call
# Third call with different params - should hit database
await routes._get_cached_store_agents(
await store_cache._get_cached_store_agents(
featured=True, # Different param
creator=None,
sorted_by=None,
@@ -94,7 +88,7 @@ class TestCacheDeletion:
assert mock_db.call_count == 2 # New DB call
# Delete specific cache entry
deleted = routes._get_cached_store_agents.cache_delete(
deleted = store_cache._get_cached_store_agents.cache_delete(
featured=False,
creator=None,
sorted_by=None,
@@ -106,7 +100,7 @@ class TestCacheDeletion:
assert deleted is True # Entry was deleted
# Try to delete non-existent entry
deleted = routes._get_cached_store_agents.cache_delete(
deleted = store_cache._get_cached_store_agents.cache_delete(
featured=False,
creator="nonexistent",
sorted_by=None,
@@ -118,7 +112,7 @@ class TestCacheDeletion:
assert deleted is False # Entry didn't exist
# Call with deleted params - should hit database again
await routes._get_cached_store_agents(
await store_cache._get_cached_store_agents(
featured=False,
creator=None,
sorted_by=None,
@@ -130,7 +124,7 @@ class TestCacheDeletion:
assert mock_db.call_count == 3 # New DB call after deletion
# Call with featured=True - should still be cached
await routes._get_cached_store_agents(
await store_cache._get_cached_store_agents(
featured=True,
creator=None,
sorted_by=None,
@@ -141,105 +135,11 @@ class TestCacheDeletion:
)
assert mock_db.call_count == 3 # No additional DB call
@pytest.mark.asyncio
async def test_agent_details_cache_delete(self):
"""Test that specific agent details cache entries can be deleted."""
mock_response = StoreAgentDetails(
store_listing_version_id="version1",
slug="test-agent",
agent_name="Test Agent",
agent_video="https://example.com/video.mp4",
agent_image=["https://example.com/image.jpg"],
creator="testuser",
creator_avatar="https://example.com/avatar.jpg",
sub_heading="Test subheading",
description="Test description",
categories=["productivity"],
runs=100,
rating=4.5,
versions=[],
last_updated=datetime.datetime(2024, 1, 1),
)
with patch(
"backend.server.v2.store.db.get_store_agent_details",
new_callable=AsyncMock,
return_value=mock_response,
) as mock_db:
# Clear cache first
routes._get_cached_agent_details.cache_clear()
# First call - should hit database
await routes._get_cached_agent_details(
username="testuser", agent_name="testagent"
)
assert mock_db.call_count == 1
# Second call - should use cache
await routes._get_cached_agent_details(
username="testuser", agent_name="testagent"
)
assert mock_db.call_count == 1 # No additional DB call
# Delete specific entry
deleted = routes._get_cached_agent_details.cache_delete(
username="testuser", agent_name="testagent"
)
assert deleted is True
# Call again - should hit database
await routes._get_cached_agent_details(
username="testuser", agent_name="testagent"
)
assert mock_db.call_count == 2 # New DB call after deletion
@pytest.mark.asyncio
async def test_user_profile_cache_delete(self):
"""Test that user profile cache entries can be deleted."""
mock_response = ProfileDetails(
name="Test User",
username="testuser",
description="Test profile",
links=["https://example.com"],
)
with patch(
"backend.server.v2.store.db.get_user_profile",
new_callable=AsyncMock,
return_value=mock_response,
) as mock_db:
# Clear cache first
routes._get_cached_user_profile.cache_clear()
# First call - should hit database
await routes._get_cached_user_profile("user123")
assert mock_db.call_count == 1
# Second call - should use cache
await routes._get_cached_user_profile("user123")
assert mock_db.call_count == 1
# Different user - should hit database
await routes._get_cached_user_profile("user456")
assert mock_db.call_count == 2
# Delete specific user's cache
deleted = routes._get_cached_user_profile.cache_delete("user123")
assert deleted is True
# user123 should hit database again
await routes._get_cached_user_profile("user123")
assert mock_db.call_count == 3
# user456 should still be cached
await routes._get_cached_user_profile("user456")
assert mock_db.call_count == 3 # No additional DB call
@pytest.mark.asyncio
async def test_cache_info_after_deletions(self):
"""Test that cache_info correctly reflects deletions."""
# Clear all caches first
routes._get_cached_store_agents.cache_clear()
store_cache._get_cached_store_agents.cache_clear()
mock_response = StoreAgentsResponse(
agents=[],
@@ -258,7 +158,7 @@ class TestCacheDeletion:
):
# Add multiple entries
for i in range(5):
await routes._get_cached_store_agents(
await store_cache._get_cached_store_agents(
featured=False,
creator=f"creator{i}",
sorted_by=None,
@@ -269,12 +169,12 @@ class TestCacheDeletion:
)
# Check cache size
info = routes._get_cached_store_agents.cache_info()
info = store_cache._get_cached_store_agents.cache_info()
assert info["size"] == 5
# Delete some entries
for i in range(2):
deleted = routes._get_cached_store_agents.cache_delete(
deleted = store_cache._get_cached_store_agents.cache_delete(
featured=False,
creator=f"creator{i}",
sorted_by=None,
@@ -286,7 +186,7 @@ class TestCacheDeletion:
assert deleted is True
# Check cache size after deletion
info = routes._get_cached_store_agents.cache_info()
info = store_cache._get_cached_store_agents.cache_info()
assert info["size"] == 3
@pytest.mark.asyncio
@@ -307,10 +207,10 @@ class TestCacheDeletion:
new_callable=AsyncMock,
return_value=mock_response,
) as mock_db:
routes._get_cached_store_agents.cache_clear()
store_cache._get_cached_store_agents.cache_clear()
# Test with all parameters
await routes._get_cached_store_agents(
await store_cache._get_cached_store_agents(
featured=True,
creator="testuser",
sorted_by="rating",
@@ -322,7 +222,7 @@ class TestCacheDeletion:
assert mock_db.call_count == 1
# Delete with exact same parameters
deleted = routes._get_cached_store_agents.cache_delete(
deleted = store_cache._get_cached_store_agents.cache_delete(
featured=True,
creator="testuser",
sorted_by="rating",
@@ -334,7 +234,7 @@ class TestCacheDeletion:
assert deleted is True
# Try to delete with slightly different parameters
deleted = routes._get_cached_store_agents.cache_delete(
deleted = store_cache._get_cached_store_agents.cache_delete(
featured=True,
creator="testuser",
sorted_by="rating",
@@ -345,150 +245,6 @@ class TestCacheDeletion:
)
assert deleted is False # Different parameters, not in cache
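The exact-key semantics these tests exercise can be reduced to a toy model (illustrative only, not the project's cache): deletion derives its key from the call arguments the same way storage does, so any differing parameter value misses the entry entirely.

```python
# Toy model of argument-keyed cache deletion: store and delete must be
# called with identical keyword arguments to touch the same entry.

_cache: dict[tuple, object] = {}
_MISSING = object()  # sentinel so a stored None still counts as present

def cache_set(value, **kwargs) -> None:
    _cache[tuple(sorted(kwargs.items()))] = value

def cache_delete(**kwargs) -> bool:
    """Return True only if an entry with exactly these arguments existed."""
    return _cache.pop(tuple(sorted(kwargs.items())), _MISSING) is not _MISSING
```

This is why the test deleting with `creator="nonexistent"` gets `False`: the key is built from all parameters, not a subset.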
@pytest.mark.asyncio
async def test_clear_submissions_cache_page_size_consistency(self):
"""
Test that _clear_submissions_cache uses the correct page_size.
This test ensures that if the default page_size in routes changes,
the hardcoded value in _clear_submissions_cache must also change.
"""
from backend.server.v2.store.model import StoreSubmissionsResponse
mock_response = StoreSubmissionsResponse(
submissions=[],
pagination=Pagination(
total_items=0,
total_pages=1,
current_page=1,
page_size=20,
),
)
with patch(
"backend.server.v2.store.db.get_store_submissions",
new_callable=AsyncMock,
return_value=mock_response,
):
# Clear cache first
routes._get_cached_submissions.cache_clear()
# Populate cache with multiple pages using the default page_size
DEFAULT_PAGE_SIZE = 20 # This should match the default in routes.py
user_id = "test_user"
# Add entries for pages 1-5
for page in range(1, 6):
await routes._get_cached_submissions(
user_id=user_id, page=page, page_size=DEFAULT_PAGE_SIZE
)
# Verify cache has entries
cache_info_before = routes._get_cached_submissions.cache_info()
assert cache_info_before["size"] == 5
# Call _clear_submissions_cache
routes._clear_submissions_cache(user_id, num_pages=20)
# All entries should be cleared
cache_info_after = routes._get_cached_submissions.cache_info()
assert (
cache_info_after["size"] == 0
), "Cache should be empty after _clear_submissions_cache"
@pytest.mark.asyncio
async def test_clear_submissions_cache_detects_page_size_mismatch(self):
"""
Test that detects if _clear_submissions_cache is using wrong page_size.
If this test fails, it means the hardcoded page_size in _clear_submissions_cache
doesn't match the default page_size used in the routes.
"""
from backend.server.v2.store.model import StoreSubmissionsResponse
mock_response = StoreSubmissionsResponse(
submissions=[],
pagination=Pagination(
total_items=0,
total_pages=1,
current_page=1,
page_size=20,
),
)
with patch(
"backend.server.v2.store.db.get_store_submissions",
new_callable=AsyncMock,
return_value=mock_response,
):
# Clear cache first
routes._get_cached_submissions.cache_clear()
# WRONG_PAGE_SIZE simulates what happens if someone changes
# the default page_size in routes but forgets to update _clear_submissions_cache
WRONG_PAGE_SIZE = 25 # Different from the hardcoded value in cache.py
user_id = "test_user"
# Populate cache with the "wrong" page_size
for page in range(1, 6):
await routes._get_cached_submissions(
user_id=user_id, page=page, page_size=WRONG_PAGE_SIZE
)
# Verify cache has entries
cache_info_before = routes._get_cached_submissions.cache_info()
assert cache_info_before["size"] == 5
# Call _clear_submissions_cache (which uses page_size=20 hardcoded)
routes._clear_submissions_cache(user_id, num_pages=20)
# If page_size is mismatched, entries won't be cleared
cache_info_after = routes._get_cached_submissions.cache_info()
# This assertion will FAIL if _clear_submissions_cache uses wrong page_size
assert (
cache_info_after["size"] == 5
), "Cache entries with different page_size should NOT be cleared (this is expected)"
@pytest.mark.asyncio
async def test_my_agents_cache_needs_clearing_too(self):
"""
Test that demonstrates _get_cached_my_agents also needs cache clearing.
Currently there's no _clear_my_agents_cache function, but there should be.
"""
from backend.server.v2.store.model import MyAgentsResponse
mock_response = MyAgentsResponse(
agents=[],
pagination=Pagination(
total_items=0,
total_pages=1,
current_page=1,
page_size=20,
),
)
with patch(
"backend.server.v2.store.db.get_my_agents",
new_callable=AsyncMock,
return_value=mock_response,
):
routes._get_cached_my_agents.cache_clear()
DEFAULT_PAGE_SIZE = 20
user_id = "test_user"
# Populate cache
for page in range(1, 6):
await routes._get_cached_my_agents(
user_id=user_id, page=page, page_size=DEFAULT_PAGE_SIZE
)
cache_info = routes._get_cached_my_agents.cache_info()
assert cache_info["size"] == 5
# NOTE: Currently there's no _clear_my_agents_cache function
# If we implement one, it should clear all pages consistently
# For now we document this as a TODO
if __name__ == "__main__":
# Run the tests

View File

@@ -12,6 +12,7 @@ Provides decorators for caching function results with support for:
import asyncio
import inspect
import logging
import pickle
import threading
import time
from dataclasses import dataclass
@@ -58,9 +59,7 @@ def _get_cache_pool() -> ConnectionPool:
return _cache_pool
def _get_redis_client() -> Redis:
"""Get a Redis client from the connection pool."""
return Redis(connection_pool=_get_cache_pool())
redis = Redis(connection_pool=_get_cache_pool())
@dataclass
@@ -110,11 +109,11 @@ def _make_hashable_key(
return (hashable_args, hashable_kwargs)
def _make_redis_key(key: tuple[Any, ...]) -> str:
def _make_redis_key(key: tuple[Any, ...], func_name: str) -> str:
"""Convert a hashable key tuple to a Redis key string."""
# Ensure key is already hashable
hashable_key = key if isinstance(key, tuple) else (key,)
return f"cache:{hash(hashable_key)}"
return f"cache:{func_name}:{hash(hashable_key)}"
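The keying change above can be shown standalone (the function mirrors the diff, but this copy is a self-contained illustration): including the function name in the Redis key prevents two cached functions called with identical arguments from colliding on the same `cache:<hash>` entry.

```python
# Standalone copy of the namespaced key builder from the diff: the same
# argument tuple now yields distinct keys for distinct cached functions.

from typing import Any

def make_redis_key(key: tuple[Any, ...], func_name: str) -> str:
    """Convert a hashable key tuple to a function-scoped Redis key string."""
    hashable_key = key if isinstance(key, tuple) else (key,)
    return f"cache:{func_name}:{hash(hashable_key)}"
```

This also lets `cache_clear` and `cache_info` scan only `cache:{func_name}:*` instead of every `cache:*` key, as the later hunks do. Note that `hash()` of strings is randomized per process, so these keys are only stable within one interpreter run, which is fine for a shared cache populated and read by long-lived workers but worth keeping in mind.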
@runtime_checkable
@@ -177,9 +176,6 @@ def cached(
def _get_from_redis(redis_key: str) -> Any | None:
"""Get value from Redis, optionally refreshing TTL."""
try:
import pickle
redis = _get_redis_client()
if refresh_ttl_on_get:
# Use GETEX to get value and refresh expiry atomically
cached_bytes = redis.getex(redis_key, ex=ttl_seconds)
@@ -197,9 +193,6 @@ def cached(
def _set_to_redis(redis_key: str, value: Any) -> None:
"""Set value in Redis with TTL."""
try:
import pickle
redis = _get_redis_client()
pickled_value = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
redis.setex(redis_key, ttl_seconds, pickled_value)
except Exception as e:
@@ -239,13 +232,15 @@ def cached(
loop = None
if loop not in _event_loop_locks:
_event_loop_locks[loop] = asyncio.Lock()
return _event_loop_locks.setdefault(loop, asyncio.Lock())
return _event_loop_locks[loop]
@wraps(target_func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs):
key = _make_hashable_key(args, kwargs)
redis_key = _make_redis_key(key) if shared_cache else ""
redis_key = (
_make_redis_key(key, target_func.__name__) if shared_cache else ""
)
# Fast path: check cache without lock
if shared_cache:
@@ -290,7 +285,9 @@ def cached(
@wraps(target_func)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs):
key = _make_hashable_key(args, kwargs)
redis_key = _make_redis_key(key) if shared_cache else ""
redis_key = (
_make_redis_key(key, target_func.__name__) if shared_cache else ""
)
# Fast path: check cache without lock
if shared_cache:
@@ -332,13 +329,14 @@ def cached(
def cache_clear(pattern: str | None = None) -> None:
"""Clear cache entries. If pattern provided, clear matching entries."""
if shared_cache:
redis = _get_redis_client()
if pattern:
# Clear entries matching pattern
keys = list(redis.scan_iter(f"cache:{pattern}", count=100))
keys = list(
redis.scan_iter(f"cache:{target_func.__name__}:{pattern}")
)
else:
# Clear all cache keys
keys = list(redis.scan_iter("cache:*", count=100))
keys = list(redis.scan_iter(f"cache:{target_func.__name__}:*"))
if keys:
pipeline = redis.pipeline()
@@ -356,8 +354,7 @@ def cached(
def cache_info() -> dict[str, int | None]:
if shared_cache:
redis = _get_redis_client()
cache_keys = list(redis.scan_iter("cache:*"))
cache_keys = list(redis.scan_iter(f"cache:{target_func.__name__}:*"))
return {
"size": len(cache_keys),
"maxsize": None, # Redis manages its own size
@@ -374,8 +371,7 @@ def cached(
"""Delete a specific cache entry. Returns True if entry existed."""
key = _make_hashable_key(args, kwargs)
if shared_cache:
redis = _get_redis_client()
redis_key = _make_redis_key(key)
redis_key = _make_redis_key(key, target_func.__name__)
if redis.exists(redis_key):
redis.delete(redis_key)
return True
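The change running through these hunks namespaces every Redis key with the decorated function's name, so `cache_clear()` and `cache_info()` only touch that function's entries instead of every `cache:*` key. A minimal sketch of the idea (hypothetical helper, not the exact implementation):

```python
import hashlib

def make_redis_key(func_name: str, args: tuple, kwargs: dict) -> str:
    # Namespacing by function name keeps cache_clear() from wiping
    # entries that belong to other cached functions.
    hashable = (args, tuple(sorted(kwargs.items())))
    digest = hashlib.sha256(repr(hashable).encode()).hexdigest()[:16]
    return f"cache:{func_name}:{digest}"

# Same function + args -> same key; different function -> different namespace.
assert make_redis_key("get_user", (42,), {}) == make_redis_key("get_user", (42,), {})
assert make_redis_key("get_user", (42,), {}) != make_redis_key("get_org", (42,), {})
```

With this layout, clearing one function's cache becomes a scan for `cache:get_user:*` rather than `cache:*`.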

View File

@@ -12,7 +12,7 @@ import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock, patch
from unittest.mock import Mock
import pytest
@@ -677,330 +677,447 @@ class TestCache:
class TestSharedCache:
"""Tests for shared_cache functionality using Redis."""
@pytest.fixture(autouse=True)
def setup_redis_mock(self):
"""Mock Redis client for testing."""
with patch("backend.util.cache._get_redis_client") as mock_redis_func:
# Configure mock to behave like Redis
mock_redis = Mock()
self.mock_redis = mock_redis
self.redis_storage = {}
def mock_get(key):
return self.redis_storage.get(key)
def mock_getex(key, ex=None):
# GETEX returns value and optionally refreshes TTL
return self.redis_storage.get(key)
def mock_set(key, value):
self.redis_storage[key] = value
return True
def mock_setex(key, ttl, value):
self.redis_storage[key] = value
return True
def mock_exists(key):
return 1 if key in self.redis_storage else 0
def mock_delete(key):
if key in self.redis_storage:
del self.redis_storage[key]
return 1
return 0
def mock_scan_iter(pattern, count=None):
# Pattern is a string like "cache:*", keys in storage are strings
prefix = pattern.rstrip("*")
return [
k
for k in self.redis_storage.keys()
if isinstance(k, str) and k.startswith(prefix)
]
def mock_pipeline():
pipe = Mock()
deleted_keys = []
def pipe_delete(key):
deleted_keys.append(key)
return pipe
def pipe_execute():
# Actually delete the keys when pipeline executes
for key in deleted_keys:
self.redis_storage.pop(key, None)
deleted_keys.clear()
return []
pipe.delete = Mock(side_effect=pipe_delete)
pipe.execute = Mock(side_effect=pipe_execute)
return pipe
mock_redis.get = Mock(side_effect=mock_get)
mock_redis.getex = Mock(side_effect=mock_getex)
mock_redis.set = Mock(side_effect=mock_set)
mock_redis.setex = Mock(side_effect=mock_setex)
mock_redis.exists = Mock(side_effect=mock_exists)
mock_redis.delete = Mock(side_effect=mock_delete)
mock_redis.scan_iter = Mock(side_effect=mock_scan_iter)
mock_redis.pipeline = Mock(side_effect=mock_pipeline)
# Make _get_redis_client return the mock
mock_redis_func.return_value = mock_redis
yield mock_redis
# Cleanup
self.redis_storage.clear()
"""Tests for shared_cache (Redis-backed) functionality."""
def test_sync_shared_cache_basic(self):
"""Test basic shared cache functionality with sync function."""
call_count = 0
@cached(shared_cache=True, ttl_seconds=300)
def shared_function(x: int) -> int:
@cached(ttl_seconds=30, shared_cache=True)
def shared_sync_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
return x * 10
return x + y
# First call - should miss cache
result1 = shared_function(5)
assert result1 == 50
# Clear any existing cache
shared_sync_function.cache_clear()
# First call
result1 = shared_sync_function(10, 20)
assert result1 == 30
assert call_count == 1
assert self.mock_redis.get.called
assert self.mock_redis.setex.called # setex is used for TTL
# Second call - should hit cache
result2 = shared_function(5)
assert result2 == 50
assert call_count == 1 # Function not called again
# Second call - should use Redis cache
result2 = shared_sync_function(10, 20)
assert result2 == 30
assert call_count == 1
# Different args - should call function again
result3 = shared_sync_function(15, 25)
assert result3 == 40
assert call_count == 2
# Cleanup
shared_sync_function.cache_clear()
@pytest.mark.asyncio
async def test_async_shared_cache_basic(self):
"""Test basic shared cache functionality with async function."""
call_count = 0
@cached(shared_cache=True, ttl_seconds=300)
async def async_shared_function(x: int) -> int:
@cached(ttl_seconds=30, shared_cache=True)
async def shared_async_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
return x * 20
return x + y
# First call - should miss cache
result1 = await async_shared_function(3)
assert result1 == 60
assert call_count == 1
assert self.mock_redis.get.called
assert self.mock_redis.setex.called # setex is used for TTL
# Second call - should hit cache
result2 = await async_shared_function(3)
assert result2 == 60
assert call_count == 1 # Function not called again
def test_sync_shared_cache_with_ttl(self):
"""Test shared cache with TTL using sync function."""
call_count = 0
@cached(shared_cache=True, ttl_seconds=60)
def shared_ttl_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 30
# Clear any existing cache
shared_async_function.cache_clear()
# First call
result1 = shared_ttl_function(2)
assert result1 == 60
assert call_count == 1
assert self.mock_redis.setex.called
# Second call - should use cache
result2 = shared_ttl_function(2)
assert result2 == 60
result1 = await shared_async_function(10, 20)
assert result1 == 30
assert call_count == 1
@pytest.mark.asyncio
async def test_async_shared_cache_with_ttl(self):
"""Test shared cache with TTL using async function."""
call_count = 0
@cached(shared_cache=True, ttl_seconds=120)
async def async_shared_ttl_function(x: int) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
return x * 40
# First call
result1 = await async_shared_ttl_function(4)
assert result1 == 160
assert call_count == 1
assert self.mock_redis.setex.called
# Second call - should use cache
result2 = await async_shared_ttl_function(4)
assert result2 == 160
# Second call - should use Redis cache
result2 = await shared_async_function(10, 20)
assert result2 == 30
assert call_count == 1
def test_shared_cache_clear(self):
"""Test clearing shared cache."""
call_count = 0
@cached(shared_cache=True, ttl_seconds=300)
def clearable_shared_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 50
# First call
result1 = clearable_shared_function(1)
assert result1 == 50
assert call_count == 1
# Second call - should use cache
result2 = clearable_shared_function(1)
assert result2 == 50
assert call_count == 1
# Clear cache
clearable_shared_function.cache_clear()
assert self.mock_redis.pipeline.called
# Third call - should execute function again
result3 = clearable_shared_function(1)
assert result3 == 50
# Different args - should call function again
result3 = await shared_async_function(15, 25)
assert result3 == 40
assert call_count == 2
# Cleanup
shared_async_function.cache_clear()
def test_shared_cache_ttl_refresh(self):
"""Test TTL refresh functionality with shared cache."""
call_count = 0
@cached(ttl_seconds=2, shared_cache=True, refresh_ttl_on_get=True)
def ttl_refresh_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 10
# Clear any existing cache
ttl_refresh_function.cache_clear()
# First call
result1 = ttl_refresh_function(3)
assert result1 == 30
assert call_count == 1
# Wait 1 second
time.sleep(1)
# Second call - should refresh TTL and use cache
result2 = ttl_refresh_function(3)
assert result2 == 30
assert call_count == 1
# Wait another 1.5 seconds (total 2.5s from first call, 1.5s from second)
time.sleep(1.5)
# Third call - TTL should have been refreshed, so still cached
result3 = ttl_refresh_function(3)
assert result3 == 30
assert call_count == 1
# Wait 2.1 seconds - now it should expire
time.sleep(2.1)
# Fourth call - should call function again
result4 = ttl_refresh_function(3)
assert result4 == 30
assert call_count == 2
# Cleanup
ttl_refresh_function.cache_clear()
def test_shared_cache_without_ttl_refresh(self):
"""Test that TTL doesn't refresh when refresh_ttl_on_get=False."""
call_count = 0
@cached(ttl_seconds=2, shared_cache=True, refresh_ttl_on_get=False)
def no_ttl_refresh_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 10
# Clear any existing cache
no_ttl_refresh_function.cache_clear()
# First call
result1 = no_ttl_refresh_function(4)
assert result1 == 40
assert call_count == 1
# Wait 1 second
time.sleep(1)
# Second call - should use cache but NOT refresh TTL
result2 = no_ttl_refresh_function(4)
assert result2 == 40
assert call_count == 1
# Wait another 1.1 seconds (total 2.1s from first call)
time.sleep(1.1)
# Third call - should have expired
result3 = no_ttl_refresh_function(4)
assert result3 == 40
assert call_count == 2
# Cleanup
no_ttl_refresh_function.cache_clear()
def test_shared_cache_complex_objects(self):
"""Test caching complex objects with shared cache (pickle serialization)."""
call_count = 0
@cached(ttl_seconds=30, shared_cache=True)
def complex_object_function(x: int) -> dict:
nonlocal call_count
call_count += 1
return {
"number": x,
"squared": x**2,
"nested": {"list": [1, 2, x], "tuple": (x, x * 2)},
"string": f"value_{x}",
}
# Clear any existing cache
complex_object_function.cache_clear()
# First call
result1 = complex_object_function(5)
assert result1["number"] == 5
assert result1["squared"] == 25
assert result1["nested"]["list"] == [1, 2, 5]
assert call_count == 1
# Second call - should use cache
result2 = complex_object_function(5)
assert result2 == result1
assert call_count == 1
# Cleanup
complex_object_function.cache_clear()
def test_shared_cache_info(self):
"""Test cache_info for shared cache."""
@cached(ttl_seconds=30, shared_cache=True)
def info_shared_function(x: int) -> int:
return x * 2
# Clear any existing cache
info_shared_function.cache_clear()
# Check initial info
info = info_shared_function.cache_info()
assert info["size"] == 0
assert info["maxsize"] is None # Redis manages size
assert info["ttl_seconds"] == 30
# Add some entries
info_shared_function(1)
info_shared_function(2)
info_shared_function(3)
info = info_shared_function.cache_info()
assert info["size"] == 3
# Cleanup
info_shared_function.cache_clear()
def test_shared_cache_delete(self):
"""Test deleting specific shared cache entry."""
"""Test selective deletion with shared cache."""
call_count = 0
@cached(shared_cache=True, ttl_seconds=300)
def deletable_shared_function(x: int) -> int:
@cached(ttl_seconds=30, shared_cache=True)
def delete_shared_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 60
return x * 3
# First call for x=1
result1 = deletable_shared_function(1)
assert result1 == 60
assert call_count == 1
# Clear any existing cache
delete_shared_function.cache_clear()
# First call for x=2
result2 = deletable_shared_function(2)
assert result2 == 120
assert call_count == 2
# Add entries
delete_shared_function(1)
delete_shared_function(2)
delete_shared_function(3)
assert call_count == 3
# Delete entry for x=1
was_deleted = deletable_shared_function.cache_delete(1)
# Verify cached
delete_shared_function(1)
delete_shared_function(2)
assert call_count == 3
# Delete specific entry
was_deleted = delete_shared_function.cache_delete(2)
assert was_deleted is True
# Call with x=1 should execute function again
result3 = deletable_shared_function(1)
assert result3 == 60
assert call_count == 3
# Entry for x=2 should be gone
delete_shared_function(2)
assert call_count == 4
# Call with x=2 should still use cache
result4 = deletable_shared_function(2)
assert result4 == 120
assert call_count == 3
# Others should still be cached
delete_shared_function(1)
delete_shared_function(3)
assert call_count == 4
def test_shared_cache_error_handling(self):
"""Test that Redis errors are handled gracefully."""
call_count = 0
# Try to delete non-existent
was_deleted = delete_shared_function.cache_delete(99)
assert was_deleted is False
@cached(shared_cache=True, ttl_seconds=300)
def error_prone_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 70
# Simulate Redis error
self.mock_redis.get.side_effect = Exception("Redis connection error")
# Function should still work
result = error_prone_function(1)
assert result == 70
assert call_count == 1
# Cleanup
delete_shared_function.cache_clear()
@pytest.mark.asyncio
async def test_async_shared_cache_error_handling(self):
"""Test that Redis errors are handled gracefully in async functions."""
async def test_async_shared_cache_thundering_herd(self):
"""Test that shared cache prevents thundering herd for async functions."""
call_count = 0
@cached(shared_cache=True, ttl_seconds=300)
async def async_error_prone_function(x: int) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
return x * 80
# Simulate Redis error
self.mock_redis.get.side_effect = Exception("Redis connection error")
# Function should still work
result = await async_error_prone_function(1)
assert result == 80
assert call_count == 1
def test_shared_cache_with_complex_types(self):
"""Test shared cache with complex return types (lists, dicts)."""
call_count = 0
@cached(shared_cache=True, ttl_seconds=300)
def complex_return_function(x: int) -> dict:
nonlocal call_count
call_count += 1
return {"value": x, "squared": x * x, "list": [1, 2, 3]}
# First call
result1 = complex_return_function(5)
assert result1 == {"value": 5, "squared": 25, "list": [1, 2, 3]}
assert call_count == 1
# Second call - should use cache
result2 = complex_return_function(5)
assert result2 == {"value": 5, "squared": 25, "list": [1, 2, 3]}
assert call_count == 1
@pytest.mark.asyncio
async def test_async_thundering_herd_shared_cache(self):
"""Test thundering herd protection with shared cache."""
call_count = 0
@cached(shared_cache=True, ttl_seconds=300)
async def slow_shared_function(x: int) -> int:
@cached(ttl_seconds=30, shared_cache=True)
async def shared_slow_function(x: int) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.1)
return x * x
# Launch concurrent coroutines
tasks = [slow_shared_function(9) for _ in range(5)]
# Clear any existing cache
shared_slow_function.cache_clear()
# Launch multiple concurrent tasks
tasks = [shared_slow_function(8) for _ in range(10)]
results = await asyncio.gather(*tasks)
# All results should be the same
assert all(result == 81 for result in results)
# Only one coroutine should have executed the function
# All should return same result
assert all(r == 64 for r in results)
# Only one should have executed
assert call_count == 1
def test_shared_cache_info(self):
"""Test cache_info with shared cache."""
# Cleanup
shared_slow_function.cache_clear()
@cached(shared_cache=True, maxsize=100, ttl_seconds=300)
def info_function(x: int) -> int:
return x * 90
def test_shared_cache_clear_pattern(self):
"""Test pattern-based cache clearing (Redis feature)."""
# Call the function to populate cache
info_function(1)
@cached(ttl_seconds=30, shared_cache=True)
def pattern_function(category: str, item: int) -> str:
return f"{category}_{item}"
# Get cache info
info = info_function.cache_info()
assert "size" in info
assert info["maxsize"] is None # Redis manages its own size
assert info["ttl_seconds"] == 300
# Clear any existing cache
pattern_function.cache_clear()
# Add various entries
pattern_function("fruit", 1)
pattern_function("fruit", 2)
pattern_function("vegetable", 1)
pattern_function("vegetable", 2)
info = pattern_function.cache_info()
assert info["size"] == 4
# Note: Pattern clearing with wildcards requires specific Redis scan
# implementation. The current code clears by pattern but needs
# adjustment for partial matching. For now, test full clear.
pattern_function.cache_clear()
info = pattern_function.cache_info()
assert info["size"] == 0
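Redis `SCAN` with `MATCH` uses glob-style patterns, which is what the pattern-clearing code leans on. The lookup can be modeled in a few lines with `fnmatch` (the storage dict is illustrative):

```python
import fnmatch

storage = {
    "cache:pattern_function:aaa": 1,
    "cache:pattern_function:bbb": 2,
    "cache:other_function:ccc": 3,
}

def scan_iter(pattern: str) -> list[str]:
    # Redis SCAN MATCH uses glob-style patterns; fnmatch models that here.
    return [k for k in storage if fnmatch.fnmatch(k, pattern)]

assert sorted(scan_iter("cache:pattern_function:*")) == [
    "cache:pattern_function:aaa",
    "cache:pattern_function:bbb",
]
assert len(scan_iter("cache:*")) == 3
```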
def test_shared_vs_local_cache_isolation(self):
"""Test that shared and local caches are isolated."""
shared_count = 0
local_count = 0
@cached(ttl_seconds=30, shared_cache=True)
def shared_function(x: int) -> int:
nonlocal shared_count
shared_count += 1
return x * 2
@cached(ttl_seconds=30, shared_cache=False)
def local_function(x: int) -> int:
nonlocal local_count
local_count += 1
return x * 2
# Clear caches
shared_function.cache_clear()
local_function.cache_clear()
# Call both with same args
shared_result = shared_function(5)
local_result = local_function(5)
assert shared_result == local_result == 10
assert shared_count == 1
assert local_count == 1
# Call again - both should use their respective caches
shared_function(5)
local_function(5)
assert shared_count == 1
assert local_count == 1
# Clear only shared cache
shared_function.cache_clear()
# Shared should recompute, local should still use cache
shared_function(5)
local_function(5)
assert shared_count == 2
assert local_count == 1
# Cleanup
shared_function.cache_clear()
local_function.cache_clear()
@pytest.mark.asyncio
async def test_shared_cache_concurrent_different_keys(self):
"""Test that concurrent calls with different keys work correctly."""
call_counts = {}
@cached(ttl_seconds=30, shared_cache=True)
async def multi_key_function(key: str) -> str:
if key not in call_counts:
call_counts[key] = 0
call_counts[key] += 1
await asyncio.sleep(0.05)
return f"result_{key}"
# Clear cache
multi_key_function.cache_clear()
# Launch concurrent tasks with different keys
keys = ["a", "b", "c", "d", "e"]
tasks = []
for key in keys:
# Multiple calls per key
tasks.extend([multi_key_function(key) for _ in range(3)])
results = await asyncio.gather(*tasks)
# Verify results
for i, key in enumerate(keys):
expected = f"result_{key}"
# Each key appears 3 times in results
key_results = results[i * 3 : (i + 1) * 3]
assert all(r == expected for r in key_results)
# Each key should only be computed once
for key in keys:
assert call_counts[key] == 1
# Cleanup
multi_key_function.cache_clear()
def test_shared_cache_performance_comparison(self):
"""Compare performance of shared vs local cache."""
import statistics
shared_times = []
local_times = []
@cached(ttl_seconds=30, shared_cache=True)
def shared_perf_function(x: int) -> int:
time.sleep(0.01) # Simulate work
return x * 2
@cached(ttl_seconds=30, shared_cache=False)
def local_perf_function(x: int) -> int:
time.sleep(0.01) # Simulate work
return x * 2
# Clear caches
shared_perf_function.cache_clear()
local_perf_function.cache_clear()
# Warm up both caches
for i in range(5):
shared_perf_function(i)
local_perf_function(i)
# Measure cache hit times
for i in range(5):
# Shared cache hit
start = time.time()
shared_perf_function(i)
shared_times.append(time.time() - start)
# Local cache hit
start = time.time()
local_perf_function(i)
local_times.append(time.time() - start)
# Local cache should be faster (no Redis round-trip)
avg_shared = statistics.mean(shared_times)
avg_local = statistics.mean(local_times)
print(f"Avg shared cache hit time: {avg_shared:.6f}s")
print(f"Avg local cache hit time: {avg_local:.6f}s")
# Local should be significantly faster for cache hits
# Redis adds network latency even for cache hits
assert avg_local < avg_shared
# Cleanup
shared_perf_function.cache_clear()
local_perf_function.cache_clear()

View File

@@ -33,12 +33,14 @@ def get_database_manager_client() -> "DatabaseManagerClient":
@thread_cached
def get_database_manager_async_client() -> "DatabaseManagerAsyncClient":
def get_database_manager_async_client(
should_retry: bool = True,
) -> "DatabaseManagerAsyncClient":
"""Get a thread-cached DatabaseManagerAsyncClient with request retry enabled."""
from backend.executor import DatabaseManagerAsyncClient
from backend.util.service import get_service_client
return get_service_client(DatabaseManagerAsyncClient, request_retry=True)
return get_service_client(DatabaseManagerAsyncClient, request_retry=should_retry)
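Since the client factory is `@thread_cached`, each thread memoizes per argument set, so `should_retry=True` and `should_retry=False` yield distinct cached clients. A simplified sketch of such a decorator (names are illustrative, not the project's implementation):

```python
import threading
from functools import wraps

def thread_cached(func):
    # One cache per thread; separate threads each call func independently.
    local = threading.local()

    @wraps(func)
    def wrapper(*args, **kwargs):
        key = (args, tuple(sorted(kwargs.items())))
        cache = getattr(local, "cache", None)
        if cache is None:
            cache = local.cache = {}
        if key not in cache:
            cache[key] = func(*args, **kwargs)
        return cache[key]

    return wrapper

calls = 0

@thread_cached
def client(should_retry: bool = True) -> str:
    global calls
    calls += 1
    return f"client(retry={should_retry})"

assert client() is client()          # same thread, same args -> cached
assert calls == 1
assert client(should_retry=False) != client()  # different args -> new entry
assert calls == 2
```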
@thread_cached

View File

@@ -86,3 +86,9 @@ class GraphValidationError(ValueError):
for node_id, errors in self.node_errors.items()
]
)
class DatabaseError(Exception):
"""Raised when there is an error interacting with the database"""
pass

View File

@@ -35,6 +35,12 @@ class Flag(str, Enum):
AI_ACTIVITY_STATUS = "ai-agent-execution-summary"
BETA_BLOCKS = "beta-blocks"
AGENT_ACTIVITY = "agent-activity"
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
def is_configured() -> bool:
"""Check if LaunchDarkly is configured with an SDK key."""
return bool(settings.secrets.launch_darkly_sdk_key)
def get_client() -> LDClient:

View File

@@ -1,26 +1,22 @@
import json
import logging
import re
from typing import Any, Type, TypeGuard, TypeVar, overload
from typing import Any, Type, TypeVar, overload
import jsonschema
import orjson
from fastapi.encoders import jsonable_encoder
from fastapi.encoders import jsonable_encoder as to_dict
from prisma import Json
from pydantic import BaseModel
from .truncate import truncate
from .type import type_match
logger = logging.getLogger(__name__)
# Precompiled regex to remove PostgreSQL-incompatible control characters
# Removes \u0000-\u0008, \u000B-\u000C, \u000E-\u001F, \u007F (keeps tab \u0009, newline \u000A, carriage return \u000D)
POSTGRES_CONTROL_CHARS = re.compile(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]")
def to_dict(data) -> dict:
if isinstance(data, BaseModel):
data = data.model_dump()
return jsonable_encoder(data)
def dumps(
data: Any, *args: Any, indent: int | None = None, option: int = 0, **kwargs: Any
) -> str:
@@ -109,34 +105,57 @@ def validate_with_jsonschema(
return str(e)
def is_list_of_basemodels(value: object) -> TypeGuard[list[BaseModel]]:
return isinstance(value, list) and all(
isinstance(item, BaseModel) for item in value
)
def _sanitize_string(value: str) -> str:
"""Remove PostgreSQL-incompatible control characters from string."""
return POSTGRES_CONTROL_CHARS.sub("", value)
def convert_pydantic_to_json(output_data: Any) -> Any:
if isinstance(output_data, BaseModel):
return output_data.model_dump()
if is_list_of_basemodels(output_data):
return [item.model_dump() for item in output_data]
return output_data
def sanitize_json(data: Any) -> Any:
try:
# Use two-pass approach for consistent string sanitization:
# 1. First convert to basic JSON-serializable types (handles Pydantic models)
# 2. Then sanitize strings in the result
basic_result = to_dict(data)
return to_dict(basic_result, custom_encoder={str: _sanitize_string})
except Exception as e:
# Log the failure and fall back to string representation
logger.error(
"SafeJson fallback to string representation due to serialization error: %s (%s). "
"Data type: %s, Data preview: %s",
type(e).__name__,
truncate(str(e), 200),
type(data).__name__,
truncate(str(data), 100),
)
# Ultimate fallback: convert to string representation and sanitize
return _sanitize_string(str(data))
def SafeJson(data: Any) -> Json:
class SafeJson(Json):
"""
Safely serialize data and return Prisma's Json type.
Sanitizes null bytes to prevent PostgreSQL 22P05 errors.
"""
if isinstance(data, BaseModel):
json_string = data.model_dump_json(
warnings="error",
exclude_none=True,
fallback=lambda v: None,
)
else:
json_string = dumps(data, default=lambda v: None)
Sanitizes control characters to prevent PostgreSQL 22P05 errors.
# Remove PostgreSQL-incompatible control characters in single regex operation
sanitized_json = POSTGRES_CONTROL_CHARS.sub("", json_string)
return Json(json.loads(sanitized_json))
This class:
1. Converts Pydantic models to dicts (recursively using to_dict)
2. Recursively removes PostgreSQL-incompatible control characters from strings
3. Returns a Prisma Json object safe for database storage
Uses to_dict (jsonable_encoder) with a custom encoder to handle both Pydantic
conversion and control character sanitization in a two-pass approach.
Args:
data: Input data to sanitize and convert to Json
Returns:
Prisma Json object with control characters removed
Examples:
>>> SafeJson({"text": "Hello\\x00World"}) # null char removed
>>> SafeJson({"path": "C:\\\\temp"}) # backslashes preserved
>>> SafeJson({"data": "Text\\\\u0000here"}) # literal backslash-u preserved
"""
def __init__(self, data: Any):
super().__init__(sanitize_json(data))
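The core of the sanitization is the precompiled character class: it strips the control characters PostgreSQL's `jsonb` rejects while keeping tab, newline, and carriage return, and it operates on decoded strings so literal backslash sequences survive. A condensed, recursive sketch of that behavior:

```python
import re

# Control characters PostgreSQL jsonb rejects; tab (\t), newline (\n),
# and carriage return (\r) are deliberately kept.
POSTGRES_CONTROL_CHARS = re.compile(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]")

def sanitize(value):
    # Recursively strip offending characters from every string.
    if isinstance(value, str):
        return POSTGRES_CONTROL_CHARS.sub("", value)
    if isinstance(value, dict):
        return {sanitize(k): sanitize(v) for k, v in value.items()}
    if isinstance(value, list):
        return [sanitize(v) for v in value]
    return value

assert sanitize({"text": "Hello\x00World"}) == {"text": "HelloWorld"}   # null char removed
assert sanitize({"path": "C:\\temp"}) == {"path": "C:\\temp"}           # backslashes kept
assert sanitize({"note": "line1\nline2\t"}) == {"note": "line1\nline2\t"}  # \n, \t kept
```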

View File

@@ -5,8 +5,10 @@ import sentry_sdk
from pydantic import SecretStr
from sentry_sdk.integrations.anthropic import AnthropicIntegration
from sentry_sdk.integrations.asyncio import AsyncioIntegration
from sentry_sdk.integrations.launchdarkly import LaunchDarklyIntegration
from sentry_sdk.integrations.logging import LoggingIntegration
from backend.util.feature_flag import get_client, is_configured
from backend.util.settings import Settings
settings = Settings()
@@ -19,6 +21,9 @@ class DiscordChannel(str, Enum):
def sentry_init():
sentry_dsn = settings.secrets.sentry_dsn
integrations = []
if is_configured():
integrations.append(LaunchDarklyIntegration(get_client()))
sentry_sdk.init(
dsn=sentry_dsn,
traces_sample_rate=1.0,
@@ -31,7 +36,8 @@ def sentry_init():
AnthropicIntegration(
include_prompts=False,
),
],
]
+ integrations,
)
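The `sentry_init` change builds a base integration list and appends the LaunchDarkly integration only when an SDK key is configured. The shape of that pattern, reduced to plain lists (names are illustrative):

```python
def build_integrations(ld_configured: bool) -> list[str]:
    # Base integrations are always present; optional ones are appended
    # behind a configuration check, mirroring the diff above.
    integrations: list[str] = []
    if ld_configured:
        integrations.append("launchdarkly")
    return ["asyncio", "logging", "anthropic"] + integrations

assert build_integrations(False) == ["asyncio", "logging", "anthropic"]
assert build_integrations(True)[-1] == "launchdarkly"
```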

View File

@@ -8,18 +8,9 @@ from typing import Optional
from backend.util.logging import configure_logging
from backend.util.metrics import sentry_init
from backend.util.settings import set_service_name
logger = logging.getLogger(__name__)
_SERVICE_NAME = "MainProcess"
def get_service_name():
return _SERVICE_NAME
def set_service_name(name: str):
global _SERVICE_NAME
_SERVICE_NAME = name
class AppProcess(ABC):

View File

@@ -13,7 +13,7 @@ import idna
from aiohttp import FormData, abc
from tenacity import retry, retry_if_result, wait_exponential_jitter
from backend.util.json import json
from backend.util.json import loads
# Retry status codes for which we will automatically retry the request
THROTTLE_RETRY_STATUS_CODES: set[int] = {429, 500, 502, 503, 504, 408}
@@ -259,7 +259,7 @@ class Response:
"""
Parse the body as JSON and return the resulting Python object.
"""
return json.loads(
return loads(
self.content.decode(encoding or "utf-8", errors="replace"), **kwargs
)

View File

@@ -13,41 +13,80 @@ from tenacity import (
wait_exponential_jitter,
)
from backend.util.process import get_service_name
from backend.util.settings import get_service_name
logger = logging.getLogger(__name__)
# Alert threshold for excessive retries
EXCESSIVE_RETRY_THRESHOLD = 50
# Rate limiting for alerts - track last alert time per function+error combination
_alert_rate_limiter = {}
_rate_limiter_lock = threading.Lock()
ALERT_RATE_LIMIT_SECONDS = 300 # 5 minutes between same alerts
def should_send_alert(func_name: str, exception: Exception, context: str = "") -> bool:
"""Check if we should send an alert based on rate limiting."""
# Create a unique key for this function+error+context combination
error_signature = (
f"{context}:{func_name}:{type(exception).__name__}:{str(exception)[:100]}"
)
current_time = time.time()
with _rate_limiter_lock:
last_alert_time = _alert_rate_limiter.get(error_signature, 0)
if current_time - last_alert_time >= ALERT_RATE_LIMIT_SECONDS:
_alert_rate_limiter[error_signature] = current_time
return True
return False
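The rate limiter above keys alerts on a `(context, function, error-type, message-prefix)` signature, allowing one alert per signature per window. A condensed model of the same logic (using `time.monotonic()` here for clock safety; the source uses `time.time()`):

```python
import threading
import time

ALERT_RATE_LIMIT_SECONDS = 300  # 5 minutes between identical alerts
_last_alert: dict[str, float] = {}
_lock = threading.Lock()

def should_send_alert(func_name: str, exc: Exception, context: str = "") -> bool:
    # One alert per unique (context, function, error) signature per window.
    signature = f"{context}:{func_name}:{type(exc).__name__}:{str(exc)[:100]}"
    now = time.monotonic()
    with _lock:
        if now - _last_alert.get(signature, float("-inf")) >= ALERT_RATE_LIMIT_SECONDS:
            _last_alert[signature] = now
            return True
    return False

exc = ValueError("boom")
assert should_send_alert("sync_db", exc) is True       # first occurrence fires
assert should_send_alert("sync_db", exc) is False      # duplicate suppressed
assert should_send_alert("sync_db", ValueError("other")) is True  # new message, new signature
```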
def send_rate_limited_discord_alert(
func_name: str, exception: Exception, context: str, alert_msg: str, channel=None
) -> bool:
"""
Send a Discord alert with rate limiting.
Returns True if alert was sent, False if rate limited.
"""
if not should_send_alert(func_name, exception, context):
return False
try:
from backend.util.clients import get_notification_manager_client
from backend.util.metrics import DiscordChannel
notification_client = get_notification_manager_client()
notification_client.discord_system_alert(
alert_msg, channel or DiscordChannel.PLATFORM
)
return True
except Exception as alert_error:
logger.error(f"Failed to send Discord alert: {alert_error}")
return False
def _send_critical_retry_alert(
func_name: str, attempt_number: int, exception: Exception, context: str = ""
):
"""Send alert when a function is approaching the retry failure threshold."""
try:
# Import here to avoid circular imports
from backend.util.clients import get_notification_manager_client
notification_client = get_notification_manager_client()
prefix = f"{context}: " if context else ""
alert_msg = (
f"🚨 CRITICAL: Operation Approaching Failure Threshold: {prefix}'{func_name}'\n\n"
f"Current attempt: {attempt_number}/{EXCESSIVE_RETRY_THRESHOLD}\n"
f"Error: {type(exception).__name__}: {exception}\n\n"
f"This operation is about to fail permanently. Investigate immediately."
)
notification_client.discord_system_alert(alert_msg)
prefix = f"{context}: " if context else ""
if send_rate_limited_discord_alert(
func_name,
exception,
context,
f"🚨 CRITICAL: Operation Approaching Failure Threshold: {prefix}'{func_name}'\n\n"
f"Current attempt: {attempt_number}/{EXCESSIVE_RETRY_THRESHOLD}\n"
f"Error: {type(exception).__name__}: {exception}\n\n"
f"This operation is about to fail permanently. Investigate immediately.",
):
logger.critical(
f"CRITICAL ALERT SENT: Operation {func_name} at attempt {attempt_number}"
)
except Exception as alert_error:
logger.error(f"Failed to send critical retry alert: {alert_error}")
# Don't let alerting failures break the main flow
def _create_retry_callback(context: str = ""):
"""Create a retry callback with optional context."""
@@ -66,7 +105,7 @@ def _create_retry_callback(context: str = ""):
f"{type(exception).__name__}: {exception}"
)
else:
# Retry attempt - send critical alert only once at threshold
# Retry attempt - send critical alert only once at threshold (rate limited)
if attempt_number == EXCESSIVE_RETRY_THRESHOLD:
_send_critical_retry_alert(
func_name, attempt_number, exception, context
@@ -131,7 +170,7 @@ def _log_prefix(resource_name: str, conn_id: str):
def conn_retry(
resource_name: str,
action_name: str,
max_retry: int = 5,
max_retry: int = 100,
max_wait: float = 30,
):
conn_id = str(uuid4())
@@ -139,10 +178,29 @@ def conn_retry(
def on_retry(retry_state):
prefix = _log_prefix(resource_name, conn_id)
exception = retry_state.outcome.exception()
attempt_number = retry_state.attempt_number
func_name = getattr(retry_state.fn, "__name__", "unknown")
if retry_state.outcome.failed and retry_state.next_action is None:
logger.error(f"{prefix} {action_name} failed after retries: {exception}")
else:
if attempt_number == EXCESSIVE_RETRY_THRESHOLD:
if send_rate_limited_discord_alert(
func_name,
exception,
f"{resource_name}_infrastructure",
f"🚨 **Critical Infrastructure Connection Issue**\n"
f"Resource: {resource_name}\n"
f"Action: {action_name}\n"
f"Function: {func_name}\n"
f"Current attempt: {attempt_number}/{max_retry + 1}\n"
f"Error: {type(exception).__name__}: {str(exception)[:200]}{'...' if len(str(exception)) > 200 else ''}\n\n"
f"Infrastructure component is approaching failure threshold. Investigate immediately.",
):
logger.critical(
f"INFRASTRUCTURE ALERT SENT: {resource_name} at {attempt_number} attempts"
)
logger.warning(
f"{prefix} {action_name} failed: {exception}. Retrying now..."
)
@@ -218,8 +276,8 @@ def continuous_retry(*, retry_delay: float = 1.0):
@wraps(func)
async def async_wrapper(*args, **kwargs):
counter = 0
while True:
counter = 0
try:
return await func(*args, **kwargs)
except Exception as exc:

View File

@@ -1,8 +1,19 @@
import asyncio
import threading
import time
from unittest.mock import Mock, patch
import pytest
from backend.util.retry import conn_retry
from backend.util.retry import (
ALERT_RATE_LIMIT_SECONDS,
_alert_rate_limiter,
_rate_limiter_lock,
_send_critical_retry_alert,
conn_retry,
create_retry_decorator,
should_send_alert,
)
def test_conn_retry_sync_function():
@@ -47,3 +58,194 @@ async def test_conn_retry_async_function():
with pytest.raises(ValueError) as e:
await test_function()
assert str(e.value) == "Test error"
class TestRetryRateLimiting:
"""Test the rate limiting functionality for critical retry alerts."""
def setup_method(self):
"""Reset rate limiter state before each test."""
with _rate_limiter_lock:
_alert_rate_limiter.clear()
def test_should_send_alert_allows_first_occurrence(self):
"""Test that the first occurrence of an error allows alert."""
exc = ValueError("test error")
assert should_send_alert("test_func", exc, "test_context") is True
def test_should_send_alert_rate_limits_duplicate(self):
"""Test that duplicate errors are rate limited."""
exc = ValueError("test error")
# First call should be allowed
assert should_send_alert("test_func", exc, "test_context") is True
# Second call should be rate limited
assert should_send_alert("test_func", exc, "test_context") is False
def test_should_send_alert_allows_different_errors(self):
"""Test that different errors are allowed even if same function."""
exc1 = ValueError("error 1")
exc2 = ValueError("error 2")
# First error should be allowed
assert should_send_alert("test_func", exc1, "test_context") is True
# Different error should also be allowed
assert should_send_alert("test_func", exc2, "test_context") is True
def test_should_send_alert_allows_different_contexts(self):
"""Test that same error in different contexts is allowed."""
exc = ValueError("test error")
# First context should be allowed
assert should_send_alert("test_func", exc, "context1") is True
# Different context should also be allowed
assert should_send_alert("test_func", exc, "context2") is True
def test_should_send_alert_allows_different_functions(self):
"""Test that same error in different functions is allowed."""
exc = ValueError("test error")
# First function should be allowed
assert should_send_alert("func1", exc, "test_context") is True
# Different function should also be allowed
assert should_send_alert("func2", exc, "test_context") is True
def test_should_send_alert_respects_time_window(self):
"""Test that alerts are allowed again after the rate limit window."""
exc = ValueError("test error")
# First call should be allowed
assert should_send_alert("test_func", exc, "test_context") is True
# Immediately after should be rate limited
assert should_send_alert("test_func", exc, "test_context") is False
# Mock time to simulate passage of rate limit window
current_time = time.time()
with patch("backend.util.retry.time.time") as mock_time:
# Simulate time passing beyond rate limit window
mock_time.return_value = current_time + ALERT_RATE_LIMIT_SECONDS + 1
assert should_send_alert("test_func", exc, "test_context") is True
def test_should_send_alert_thread_safety(self):
"""Test that rate limiting is thread-safe."""
exc = ValueError("test error")
results = []
def check_alert():
result = should_send_alert("test_func", exc, "test_context")
results.append(result)
# Create multiple threads trying to send the same alert
threads = [threading.Thread(target=check_alert) for _ in range(10)]
# Start all threads
for thread in threads:
thread.start()
# Wait for all threads to complete
for thread in threads:
thread.join()
# Only one thread should have been allowed to send the alert
assert sum(results) == 1
assert len([r for r in results if r is True]) == 1
assert len([r for r in results if r is False]) == 9
@patch("backend.util.clients.get_notification_manager_client")
def test_send_critical_retry_alert_rate_limiting(self, mock_get_client):
"""Test that _send_critical_retry_alert respects rate limiting."""
mock_client = Mock()
mock_get_client.return_value = mock_client
exc = ValueError("spend_credits API error")
# First alert should be sent
_send_critical_retry_alert("spend_credits", 50, exc, "Service communication")
assert mock_client.discord_system_alert.call_count == 1
# Second identical alert should be rate limited (not sent)
_send_critical_retry_alert("spend_credits", 50, exc, "Service communication")
assert mock_client.discord_system_alert.call_count == 1 # Still 1, not 2
# Different error should be allowed
exc2 = ValueError("different API error")
_send_critical_retry_alert("spend_credits", 50, exc2, "Service communication")
assert mock_client.discord_system_alert.call_count == 2
@patch("backend.util.clients.get_notification_manager_client")
def test_send_critical_retry_alert_handles_notification_failure(
self, mock_get_client
):
"""Test that notification failures don't break the rate limiter."""
mock_client = Mock()
mock_client.discord_system_alert.side_effect = Exception("Notification failed")
mock_get_client.return_value = mock_client
exc = ValueError("test error")
# Should not raise exception even if notification fails
_send_critical_retry_alert("test_func", 50, exc, "test_context")
# Rate limiter should still work for subsequent calls
assert should_send_alert("test_func", exc, "test_context") is False
def test_error_signature_generation(self):
"""Test that error signatures are generated correctly for rate limiting."""
# Test with long exception message (should be truncated to 100 chars)
long_message = "x" * 200
exc = ValueError(long_message)
# Should not raise exception and should work normally
assert should_send_alert("test_func", exc, "test_context") is True
assert should_send_alert("test_func", exc, "test_context") is False
def test_real_world_scenario_spend_credits_spam(self):
"""Test the real-world scenario that was causing spam."""
# Simulate the exact error that was causing issues
exc = Exception(
"HTTP 500: Server error '500 Internal Server Error' for url 'http://autogpt-database-manager.prod-agpt.svc.cluster.local:8005/spend_credits'"
)
# First 50 attempts reach threshold - should send alert
with patch(
"backend.util.clients.get_notification_manager_client"
) as mock_get_client:
mock_client = Mock()
mock_get_client.return_value = mock_client
_send_critical_retry_alert(
"_call_method_sync", 50, exc, "Service communication"
)
assert mock_client.discord_system_alert.call_count == 1
# Next 950 failures should not send alerts (rate limited)
for _ in range(950):
_send_critical_retry_alert(
"_call_method_sync", 50, exc, "Service communication"
)
# Still only 1 alert sent total
assert mock_client.discord_system_alert.call_count == 1
@patch("backend.util.clients.get_notification_manager_client")
def test_retry_decorator_with_excessive_failures(self, mock_get_client):
"""Test retry decorator behavior when it hits the alert threshold."""
mock_client = Mock()
mock_get_client.return_value = mock_client
@create_retry_decorator(
max_attempts=60, max_wait=0.1
) # More than EXCESSIVE_RETRY_THRESHOLD, but fast
def always_failing_function():
raise ValueError("persistent failure")
with pytest.raises(ValueError):
always_failing_function()
# Should have sent exactly one alert at the threshold
assert mock_client.discord_system_alert.call_count == 1

View File
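The alert-at-threshold behavior tested above (one Discord alert exactly when the attempt count hits `EXCESSIVE_RETRY_THRESHOLD`) can be sketched in stdlib-only form. The diff's `create_retry_decorator` is built on tenacity (note the `retry_state` usage earlier); this sketch mirrors only the threshold-alert logic, with assumed names and a pluggable `alert` callable for illustration:

```python
# Stdlib-only sketch of a retry decorator that alerts once at a threshold.
# Assumed names: EXCESSIVE_RETRY_THRESHOLD, create_retry_decorator; the real
# implementation uses tenacity's stop/wait strategies instead of this loop.
import functools
import time

EXCESSIVE_RETRY_THRESHOLD = 50  # assumed to match backend.util.retry


def create_retry_decorator(max_attempts: int, max_wait: float, alert=print):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as exc:
                    # Fire the alert exactly once, on the threshold attempt.
                    if attempt == EXCESSIVE_RETRY_THRESHOLD:
                        alert(f"{func.__name__} reached {attempt} attempts: {exc}")
                    if attempt == max_attempts:
                        raise
                    time.sleep(min(0.01, max_wait))
        return wrapper
    return decorator
```

With `max_attempts=60`, a persistently failing function triggers the alert once at attempt 50 and still re-raises after attempt 60, matching `test_retry_decorator_with_excessive_failures`.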

@@ -28,11 +28,12 @@ from fastapi import FastAPI, Request, responses
from pydantic import BaseModel, TypeAdapter, create_model
import backend.util.exceptions as exceptions
from backend.monitoring.instrumentation import instrument_fastapi
from backend.util.json import to_dict
from backend.util.metrics import sentry_init
from backend.util.process import AppProcess, get_service_name
from backend.util.process import AppProcess
from backend.util.retry import conn_retry, create_retry_decorator
from backend.util.settings import Config
from backend.util.settings import Config, get_service_name
logger = logging.getLogger(__name__)
T = TypeVar("T")
@@ -283,6 +284,24 @@ class AppService(BaseAppService, ABC):
super().run()
self.fastapi_app = FastAPI()
# Add Prometheus instrumentation to all services
try:
instrument_fastapi(
self.fastapi_app,
service_name=self.service_name,
expose_endpoint=True,
endpoint="/metrics",
include_in_schema=False,
)
except ImportError:
logger.warning(
f"Prometheus instrumentation not available for {self.service_name}"
)
except Exception as e:
logger.error(
f"Failed to instrument {self.service_name} with Prometheus: {e}"
)
# Register the exposed API routes.
for attr_name, attr in vars(type(self)).items():
if getattr(attr, EXPOSED_FLAG, False):

View File

@@ -15,6 +15,17 @@ from backend.util.data import get_data_path
T = TypeVar("T", bound=BaseSettings)
_SERVICE_NAME = "MainProcess"
def get_service_name():
return _SERVICE_NAME
def set_service_name(name: str):
global _SERVICE_NAME
_SERVICE_NAME = name
class AppEnvironment(str, Enum):
LOCAL = "local"
@@ -148,6 +159,12 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default=60 * 60,
description="Time in seconds for how far back to check for late executions.",
)
max_concurrent_graph_executions_per_user: int = Field(
default=25,
ge=1,
le=1000,
description="Maximum number of concurrent graph executions allowed per user per graph.",
)
block_error_rate_threshold: float = Field(
default=0.5,
@@ -248,6 +265,7 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default="localhost",
description="The host for the RabbitMQ server",
)
rabbitmq_port: int = Field(
default=5672,
description="The port for the RabbitMQ server",
@@ -262,10 +280,12 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default="localhost",
description="The host for the Redis server",
)
redis_port: int = Field(
default=6379,
description="The port for the Redis server",
)
redis_password: str = Field(
default="",
description="The password for the Redis server (empty string if no password)",

View File

@@ -1,5 +1,5 @@
import datetime
from typing import Any, Optional
from typing import Any, Optional, cast
from prisma import Json
from pydantic import BaseModel
@@ -231,6 +231,14 @@ class TestSafeJson:
result = SafeJson(problematic_data)
assert isinstance(result, Json)
# Verify that dangerous control characters are actually removed
result_data = result.data
assert "\x00" not in str(result_data) # null byte removed
assert "\x07" not in str(result_data) # bell removed
assert "\x0C" not in str(result_data) # form feed removed
assert "\x1B" not in str(result_data) # escape removed
assert "\x7F" not in str(result_data) # delete removed
# Test that safe whitespace characters are preserved
safe_data = {
"with_tab": "text with \t tab",
@@ -241,3 +249,508 @@ class TestSafeJson:
safe_result = SafeJson(safe_data)
assert isinstance(safe_result, Json)
# Verify safe characters are preserved
safe_result_data = cast(dict[str, Any], safe_result.data)
assert isinstance(safe_result_data, dict)
with_tab = safe_result_data.get("with_tab", "")
with_newline = safe_result_data.get("with_newline", "")
with_carriage_return = safe_result_data.get("with_carriage_return", "")
assert "\t" in str(with_tab) # tab preserved
assert "\n" in str(with_newline) # newline preserved
assert "\r" in str(with_carriage_return) # carriage return preserved
def test_web_scraping_content_sanitization(self):
"""Test sanitization of typical web scraping content with null characters."""
# Simulate web content that might contain null bytes from SearchTheWebBlock
web_content = "Article title\x00Hidden null\x01Start of heading\x08Backspace\x0CForm feed content\x1FUnit separator\x7FDelete char"
result = SafeJson(web_content)
assert isinstance(result, Json)
# Verify all problematic characters are removed
sanitized_content = str(result.data)
assert "\x00" not in sanitized_content
assert "\x01" not in sanitized_content
assert "\x08" not in sanitized_content
assert "\x0C" not in sanitized_content
assert "\x1F" not in sanitized_content
assert "\x7F" not in sanitized_content
# Verify the content is still readable
assert "Article title" in sanitized_content
assert "Hidden null" in sanitized_content
assert "content" in sanitized_content
def test_legitimate_code_preservation(self):
"""Test that legitimate code with backslashes and escapes is preserved."""
# File paths with backslashes should be preserved
file_paths = {
"windows_path": "C:\\Users\\test\\file.txt",
"network_path": "\\\\server\\share\\folder",
"escaped_backslashes": "String with \\\\ double backslashes",
}
result = SafeJson(file_paths)
result_data = cast(dict[str, Any], result.data)
assert isinstance(result_data, dict)
# Verify file paths are preserved correctly (JSON converts \\\\ back to \\)
windows_path = result_data.get("windows_path", "")
network_path = result_data.get("network_path", "")
escaped_backslashes = result_data.get("escaped_backslashes", "")
assert "C:\\Users\\test\\file.txt" in str(windows_path)
assert "\\server\\share" in str(network_path)
assert "\\" in str(escaped_backslashes)
def test_legitimate_json_escapes_preservation(self):
"""Test that legitimate JSON escape sequences are preserved."""
# These should all be preserved as they're valid and useful
legitimate_escapes = {
"quotes": 'He said "Hello world!"',
"newlines": "Line 1\\nLine 2\\nLine 3",
"tabs": "Column1\\tColumn2\\tColumn3",
"unicode_chars": "Unicode: \u0048\u0065\u006c\u006c\u006f", # "Hello"
"mixed_content": "Path: C:\\\\temp\\\\file.txt\\nSize: 1024 bytes",
}
result = SafeJson(legitimate_escapes)
result_data = cast(dict[str, Any], result.data)
assert isinstance(result_data, dict)
# Verify all legitimate content is preserved
quotes = result_data.get("quotes", "")
newlines = result_data.get("newlines", "")
tabs = result_data.get("tabs", "")
unicode_chars = result_data.get("unicode_chars", "")
mixed_content = result_data.get("mixed_content", "")
assert '"' in str(quotes)
assert "Line 1" in str(newlines) and "Line 2" in str(newlines)
assert "Column1" in str(tabs) and "Column2" in str(tabs)
assert "Hello" in str(unicode_chars) # Unicode should be decoded
assert "C:" in str(mixed_content) and "temp" in str(mixed_content)
def test_regex_patterns_dont_over_match(self):
"""Test that our regex patterns don't accidentally match legitimate sequences."""
# Edge cases that could be problematic for regex
edge_cases = {
"file_with_b": "C:\\\\mybfile.txt", # Contains 'bf' but not escape sequence
"file_with_f": "C:\\\\folder\\\\file.txt", # Contains 'f' after backslashes
"json_like_string": '{"text": "\\\\bolder text"}', # Looks like JSON escape but isn't
"unicode_like": "Code: \\\\u0040 (not a real escape)", # Looks like Unicode escape
}
result = SafeJson(edge_cases)
result_data = cast(dict[str, Any], result.data)
assert isinstance(result_data, dict)
# Verify edge cases are handled correctly - no content should be lost
file_with_b = result_data.get("file_with_b", "")
file_with_f = result_data.get("file_with_f", "")
json_like_string = result_data.get("json_like_string", "")
unicode_like = result_data.get("unicode_like", "")
assert "mybfile.txt" in str(file_with_b)
assert "folder" in str(file_with_f) and "file.txt" in str(file_with_f)
assert "bolder text" in str(json_like_string)
assert "\\u0040" in str(unicode_like)
def test_programming_code_preservation(self):
"""Test that programming code with various escapes is preserved."""
# Common programming patterns that should be preserved
code_samples = {
"python_string": 'print("Hello\\\\nworld")',
"regex_pattern": "\\\\b[A-Za-z]+\\\\b", # Word boundary regex
"json_string": '{"name": "test", "path": "C:\\\\\\\\folder"}',
"sql_escape": "WHERE name LIKE '%\\\\%%'",
"javascript": 'var path = "C:\\\\\\\\Users\\\\\\\\file.js";',
}
result = SafeJson(code_samples)
result_data = cast(dict[str, Any], result.data)
assert isinstance(result_data, dict)
# Verify programming code is preserved
python_string = result_data.get("python_string", "")
regex_pattern = result_data.get("regex_pattern", "")
json_string = result_data.get("json_string", "")
sql_escape = result_data.get("sql_escape", "")
javascript = result_data.get("javascript", "")
assert "print(" in str(python_string)
assert "Hello" in str(python_string)
assert "[A-Za-z]+" in str(regex_pattern)
assert "name" in str(json_string)
assert "LIKE" in str(sql_escape)
assert "var path" in str(javascript)
def test_only_problematic_sequences_removed(self):
"""Test that ONLY PostgreSQL-problematic sequences are removed, nothing else."""
# Mix of problematic and safe content (using actual control characters)
mixed_content = {
"safe_and_unsafe": "Good text\twith tab\x00NULL BYTE\nand newline\x08BACKSPACE",
"file_path_with_null": "C:\\temp\\file\x00.txt",
"json_with_controls": '{"text": "data\x01\x0C\x1F"}',
}
result = SafeJson(mixed_content)
result_data = cast(dict[str, Any], result.data)
assert isinstance(result_data, dict)
# Verify only problematic characters are removed
safe_and_unsafe = result_data.get("safe_and_unsafe", "")
file_path_with_null = result_data.get("file_path_with_null", "")
assert "Good text" in str(safe_and_unsafe)
assert "\t" in str(safe_and_unsafe) # Tab preserved
assert "\n" in str(safe_and_unsafe) # Newline preserved
assert "\x00" not in str(safe_and_unsafe) # Null removed
assert "\x08" not in str(safe_and_unsafe) # Backspace removed
assert "C:\\temp\\file" in str(file_path_with_null)
assert ".txt" in str(file_path_with_null)
assert "\x00" not in str(file_path_with_null) # Null removed from path
def test_invalid_escape_error_prevention(self):
"""Test that SafeJson prevents 'Invalid \\escape' errors that occurred in upsert_execution_output."""
# This reproduces the exact scenario that was causing the error:
# POST /upsert_execution_output failed: Invalid \escape: line 1 column 36404 (char 36403)
# Create data with various problematic escape sequences that could cause JSON parsing errors
problematic_output_data = {
"web_content": "Article text\x00with null\x01and control\x08chars\x0C\x1F\x7F",
"file_path": "C:\\Users\\test\\file\x00.txt",
"json_like_string": '{"text": "data\x00\x08\x1F"}',
"escaped_sequences": "Text with \\u0000 and \\u0008 sequences",
"mixed_content": "Normal text\tproperly\nformatted\rwith\x00invalid\x08chars\x1Fmixed",
"large_text": "A" * 35000
+ "\x00\x08\x1F"
+ "B" * 5000, # Large text like in the error
}
# This should not raise any JSON parsing errors
result = SafeJson(problematic_output_data)
assert isinstance(result, Json)
# Verify the result is a valid Json object that can be safely stored in PostgreSQL
result_data = cast(dict[str, Any], result.data)
assert isinstance(result_data, dict)
# Verify problematic characters are removed but safe content preserved
web_content = result_data.get("web_content", "")
file_path = result_data.get("file_path", "")
large_text = result_data.get("large_text", "")
# Check that control characters are removed
assert "\x00" not in str(web_content)
assert "\x01" not in str(web_content)
assert "\x08" not in str(web_content)
assert "\x0C" not in str(web_content)
assert "\x1F" not in str(web_content)
assert "\x7F" not in str(web_content)
# Check that legitimate content is preserved
assert "Article text" in str(web_content)
assert "with null" in str(web_content)
assert "and control" in str(web_content)
assert "chars" in str(web_content)
# Check file path handling
assert "C:\\Users\\test\\file" in str(file_path)
assert ".txt" in str(file_path)
assert "\x00" not in str(file_path)
# Check large text handling (the scenario from the error at char 36403)
assert len(str(large_text)) > 35000 # Content preserved
assert "A" * 1000 in str(large_text) # A's preserved
assert "B" * 1000 in str(large_text) # B's preserved
assert "\x00" not in str(large_text) # Control chars removed
assert "\x08" not in str(large_text)
assert "\x1F" not in str(large_text)
# Most importantly: ensure the result can be JSON-serialized without errors
# This would have failed with the old approach
import json
json_string = json.dumps(result.data) # Should not raise "Invalid \escape"
assert len(json_string) > 0
# And can be parsed back
parsed_back = json.loads(json_string)
assert isinstance(parsed_back, dict)
def test_dict_containing_pydantic_models(self):
"""Test that dicts containing Pydantic models are properly serialized."""
# This reproduces the bug from PR #11187 where credential_inputs failed
model1 = SamplePydanticModel(name="Alice", age=30)
model2 = SamplePydanticModel(name="Bob", age=25)
data = {
"user1": model1,
"user2": model2,
"regular_data": "test",
}
result = SafeJson(data)
assert isinstance(result, Json)
# Verify it can be JSON serialized (this was the bug)
import json
json_string = json.dumps(result.data)
assert "Alice" in json_string
assert "Bob" in json_string
def test_nested_pydantic_in_dict(self):
"""Test deeply nested Pydantic models in dicts."""
inner_model = SamplePydanticModel(name="Inner", age=20)
middle_model = SamplePydanticModel(
name="Middle", age=30, metadata={"inner": inner_model}
)
data = {
"level1": {
"level2": {
"model": middle_model,
"other": "data",
}
}
}
result = SafeJson(data)
assert isinstance(result, Json)
import json
json_string = json.dumps(result.data)
assert "Middle" in json_string
assert "Inner" in json_string
def test_list_containing_pydantic_models_in_dict(self):
"""Test list of Pydantic models inside a dict."""
models = [SamplePydanticModel(name=f"User{i}", age=20 + i) for i in range(5)]
data = {
"users": models,
"count": len(models),
}
result = SafeJson(data)
assert isinstance(result, Json)
import json
json_string = json.dumps(result.data)
assert "User0" in json_string
assert "User4" in json_string
def test_credentials_meta_input_scenario(self):
"""Test the exact scenario from create_graph_execution that was failing."""
# Simulate CredentialsMetaInput structure
class MockCredentialsMetaInput(BaseModel):
id: str
title: Optional[str] = None
provider: str
type: str
cred_input = MockCredentialsMetaInput(
id="test-123", title="Test Credentials", provider="github", type="oauth2"
)
# This is how credential_inputs is structured in create_graph_execution
credential_inputs = {"github_creds": cred_input}
# This should work without TypeError
result = SafeJson(credential_inputs)
assert isinstance(result, Json)
# Verify it can be JSON serialized
import json
json_string = json.dumps(result.data)
assert "test-123" in json_string
assert "github" in json_string
assert "oauth2" in json_string
def test_mixed_pydantic_and_primitives(self):
"""Test complex mix of Pydantic models and primitive types."""
model = SamplePydanticModel(name="Test", age=25)
data = {
"models": [model, {"plain": "dict"}, "string", 123],
"nested": {
"model": model,
"list": [1, 2, model, 4],
"plain": "text",
},
"plain_list": [1, 2, 3],
}
result = SafeJson(data)
assert isinstance(result, Json)
import json
json_string = json.dumps(result.data)
assert "Test" in json_string
assert "plain" in json_string
def test_pydantic_model_with_control_chars_in_dict(self):
"""Test Pydantic model with control chars when nested in dict."""
model = SamplePydanticModel(
name="Test\x00User", # Has null byte
age=30,
metadata={"info": "data\x08with\x0Ccontrols"},
)
data = {"credential": model}
result = SafeJson(data)
assert isinstance(result, Json)
# Verify control characters are removed
import json
json_string = json.dumps(result.data)
assert "\x00" not in json_string
assert "\x08" not in json_string
assert "\x0C" not in json_string
assert "TestUser" in json_string # Name preserved minus null byte
def test_deeply_nested_pydantic_models_control_char_sanitization(self):
"""Test that control characters are sanitized in deeply nested Pydantic models."""
# Create nested Pydantic models with control characters at different levels
class InnerModel(BaseModel):
deep_string: str
value: int = 42
metadata: dict = {}
class MiddleModel(BaseModel):
middle_string: str
inner: InnerModel
data: str
class OuterModel(BaseModel):
outer_string: str
middle: MiddleModel
# Create test data with control characters at every nesting level
inner = InnerModel(
deep_string="Deepest\x00Level\x08Control\x0CChars", # Multiple control chars at deepest level
metadata={
"nested_key": "Nested\x1FValue\x7FDelete"
}, # Control chars in nested dict
)
middle = MiddleModel(
middle_string="Middle\x01StartOfHeading\x1FUnitSeparator",
inner=inner,
data="Some\x0BVerticalTab\x0EShiftOut",
)
outer = OuterModel(outer_string="Outer\x00Null\x07Bell", middle=middle)
# Wrap in a dict with additional control characters
data = {
"top_level": "Top\x00Level\x08Backspace",
"nested_model": outer,
"list_with_strings": [
"List\x00Item1",
"List\x0CItem2\x1F",
{"dict_in_list": "Dict\x08Value"},
],
}
# Process with SafeJson
result = SafeJson(data)
assert isinstance(result, Json)
# Verify all control characters are removed at every level
import json
json_string = json.dumps(result.data)
# Check that NO control characters remain anywhere
control_chars = [
"\x00",
"\x01",
"\x02",
"\x03",
"\x04",
"\x05",
"\x06",
"\x07",
"\x08",
"\x0B",
"\x0C",
"\x0E",
"\x0F",
"\x10",
"\x11",
"\x12",
"\x13",
"\x14",
"\x15",
"\x16",
"\x17",
"\x18",
"\x19",
"\x1A",
"\x1B",
"\x1C",
"\x1D",
"\x1E",
"\x1F",
"\x7F",
]
for char in control_chars:
assert (
char not in json_string
), f"Control character {repr(char)} found in result"
# Verify specific sanitized content is present (control chars removed but text preserved)
result_data = cast(dict[str, Any], result.data)
# Top level
assert "TopLevelBackspace" in json_string
# Outer model level
assert "OuterNullBell" in json_string
# Middle model level
assert "MiddleStartOfHeadingUnitSeparator" in json_string
assert "SomeVerticalTabShiftOut" in json_string
# Inner model level (deepest nesting)
assert "DeepestLevelControlChars" in json_string
# Nested dict in model
assert "NestedValueDelete" in json_string
# List items
assert "ListItem1" in json_string
assert "ListItem2" in json_string
assert "DictValue" in json_string
# Verify structure is preserved (not just converted to string)
assert isinstance(result_data, dict)
assert isinstance(result_data["nested_model"], dict)
assert isinstance(result_data["nested_model"]["middle"], dict)
assert isinstance(result_data["nested_model"]["middle"]["inner"], dict)
assert isinstance(result_data["list_with_strings"], list)
# Verify specific deep values are accessible and sanitized
nested_model = cast(dict[str, Any], result_data["nested_model"])
middle = cast(dict[str, Any], nested_model["middle"])
inner = cast(dict[str, Any], middle["inner"])
deep_string = inner["deep_string"]
assert deep_string == "DeepestLevelControlChars"
metadata = cast(dict[str, Any], inner["metadata"])
nested_metadata = metadata["nested_key"]
assert nested_metadata == "NestedValueDelete"

View File
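The control-character handling exercised by the tests above can be sketched as a recursive sanitizer: strip the PostgreSQL-problematic C0 controls and DEL while preserving tab, newline, and carriage return. This is a minimal reconstruction under an assumed name (`sanitize`); the real `SafeJson` also serializes Pydantic models to their dict representation before sanitizing, which is omitted here:

```python
# Sketch of the sanitization behavior tested above (assumed helper name).
import re

# C0 controls except \t (0x09), \n (0x0A), \r (0x0D), plus DEL (0x7F).
_CONTROL_CHARS = re.compile(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]")


def sanitize(value):
    """Recursively strip problematic control characters from strings,
    walking dicts and lists; other values pass through unchanged."""
    if isinstance(value, str):
        return _CONTROL_CHARS.sub("", value)
    if isinstance(value, dict):
        return {sanitize(k): sanitize(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [sanitize(v) for v in value]
    return value
```

Operating on decoded strings rather than on JSON text is the design point: backslash sequences in file paths and code samples are ordinary characters at this stage, so only real control characters are removed and nothing is over-matched.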

@@ -0,0 +1,62 @@
BEGIN;
-- Drop and recreate the StoreAgent view with isAvailable field
DROP VIEW IF EXISTS "StoreAgent";
CREATE OR REPLACE VIEW "StoreAgent" AS
WITH latest_versions AS (
SELECT
"storeListingId",
MAX(version) AS max_version
FROM "StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
GROUP BY "storeListingId"
),
agent_versions AS (
SELECT
"storeListingId",
array_agg(DISTINCT version::text ORDER BY version::text) AS versions
FROM "StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
GROUP BY "storeListingId"
)
SELECT
sl.id AS listing_id,
slv.id AS "storeListingVersionId",
slv."createdAt" AS updated_at,
sl.slug,
COALESCE(slv.name, '') AS agent_name,
slv."videoUrl" AS agent_video,
COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
slv."isFeatured" AS featured,
p.username AS creator_username, -- Allow NULL for malformed sub-agents
p."avatarUrl" AS creator_avatar, -- Allow NULL for malformed sub-agents
slv."subHeading" AS sub_heading,
slv.description,
slv.categories,
COALESCE(ar.run_count, 0::bigint) AS runs,
COALESCE(rs.avg_rating, 0.0)::double precision AS rating,
COALESCE(av.versions, ARRAY[slv.version::text]) AS versions,
slv."isAvailable" AS is_available -- Add isAvailable field to filter sub-agents
FROM "StoreListing" sl
JOIN latest_versions lv
ON sl.id = lv."storeListingId"
JOIN "StoreListingVersion" slv
ON slv."storeListingId" = lv."storeListingId"
AND slv.version = lv.max_version
AND slv."submissionStatus" = 'APPROVED'
JOIN "AgentGraph" a
ON slv."agentGraphId" = a.id
AND slv."agentGraphVersion" = a.version
LEFT JOIN "Profile" p
ON sl."owningUserId" = p."userId"
LEFT JOIN "mv_review_stats" rs
ON sl.id = rs."storeListingId"
LEFT JOIN "mv_agent_run_counts" ar
ON a.id = ar."agentGraphId"
LEFT JOIN agent_versions av
ON sl.id = av."storeListingId"
WHERE sl."isDeleted" = false
AND sl."hasApprovedVersion" = true;
COMMIT;

View File

@@ -0,0 +1,11 @@
-- CreateTable
CREATE TABLE "SearchTerms" (
"id" BIGSERIAL NOT NULL,
"createdDate" TIMESTAMP(3) NOT NULL,
"searchTerm" TEXT NOT NULL,
CONSTRAINT "SearchTerms_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE INDEX "SearchTerms_createdDate_idx" ON "SearchTerms"("createdDate");

View File

@@ -0,0 +1,16 @@
-- Create UserBalance table for atomic credit operations
-- This replaces the need for User.balance column and provides better separation of concerns
-- UserBalance records are automatically created by the application when users interact with the credit system
-- CreateTable (only if it doesn't exist)
CREATE TABLE IF NOT EXISTS "UserBalance" (
"userId" TEXT NOT NULL,
"balance" INTEGER NOT NULL DEFAULT 0,
"updatedAt" TIMESTAMP(3) NOT NULL,
CONSTRAINT "UserBalance_pkey" PRIMARY KEY ("userId"),
CONSTRAINT "UserBalance_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE
);
-- CreateIndex (only if it doesn't exist)
CREATE INDEX IF NOT EXISTS "UserBalance_userId_idx" ON "UserBalance"("userId");

View File

@@ -0,0 +1,65 @@
BEGIN;
-- AlterTable
ALTER TABLE "StoreListing" ADD COLUMN "useForOnboarding" BOOLEAN NOT NULL DEFAULT false;
-- Drop and recreate the StoreAgent view with useForOnboarding field
DROP VIEW IF EXISTS "StoreAgent";
CREATE OR REPLACE VIEW "StoreAgent" AS
WITH latest_versions AS (
SELECT
"storeListingId",
MAX(version) AS max_version
FROM "StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
GROUP BY "storeListingId"
),
agent_versions AS (
SELECT
"storeListingId",
array_agg(DISTINCT version::text ORDER BY version::text) AS versions
FROM "StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
GROUP BY "storeListingId"
)
SELECT
sl.id AS listing_id,
slv.id AS "storeListingVersionId",
slv."createdAt" AS updated_at,
sl.slug,
COALESCE(slv.name, '') AS agent_name,
slv."videoUrl" AS agent_video,
COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
slv."isFeatured" AS featured,
p.username AS creator_username, -- Allow NULL for malformed sub-agents
p."avatarUrl" AS creator_avatar, -- Allow NULL for malformed sub-agents
slv."subHeading" AS sub_heading,
slv.description,
slv.categories,
COALESCE(ar.run_count, 0::bigint) AS runs,
COALESCE(rs.avg_rating, 0.0)::double precision AS rating,
COALESCE(av.versions, ARRAY[slv.version::text]) AS versions,
slv."isAvailable" AS is_available,
COALESCE(sl."useForOnboarding", false) AS "useForOnboarding"
FROM "StoreListing" sl
JOIN latest_versions lv
ON sl.id = lv."storeListingId"
JOIN "StoreListingVersion" slv
ON slv."storeListingId" = lv."storeListingId"
AND slv.version = lv.max_version
AND slv."submissionStatus" = 'APPROVED'
JOIN "AgentGraph" a
ON slv."agentGraphId" = a.id
AND slv."agentGraphVersion" = a.version
LEFT JOIN "Profile" p
ON sl."owningUserId" = p."userId"
LEFT JOIN "mv_review_stats" rs
ON sl.id = rs."storeListingId"
LEFT JOIN "mv_agent_run_counts" ar
ON a.id = ar."agentGraphId"
LEFT JOIN agent_versions av
ON sl.id = av."storeListingId"
WHERE sl."isDeleted" = false
AND sl."hasApprovedVersion" = true;
COMMIT;

View File

@@ -413,7 +413,6 @@ pydantic-settings = "^2.10.1"
pyjwt = {version = "^2.10.1", extras = ["crypto"]}
redis = "^6.2.0"
supabase = "^2.16.0"
tenacity = "^9.1.2"
uvicorn = "^0.35.0"
[package.source]

View File

@@ -45,6 +45,7 @@ model User {
AnalyticsDetails AnalyticsDetails[]
AnalyticsMetrics AnalyticsMetrics[]
CreditTransactions CreditTransaction[]
UserBalance UserBalance?
AgentPresets AgentPreset[]
LibraryAgents LibraryAgent[]
@@ -118,9 +119,9 @@ model AgentGraph {
createdAt DateTime @default(now())
updatedAt DateTime? @updatedAt
name String?
description String?
instructions String?
recommendedScheduleCron String?
isActive Boolean @default(true)
@@ -382,9 +383,9 @@ model AgentGraphExecution {
stats Json?
// Sharing fields
isShared Boolean @default(false)
shareToken String? @unique
sharedAt DateTime?
@@index([agentGraphId, agentGraphVersion])
@@index([userId, isDeleted, createdAt])
@@ -545,7 +546,7 @@ model CreditTransaction {
createdAt DateTime @default(now())
userId String
User User? @relation(fields: [userId], references: [id], onDelete: NoAction)
User User? @relation(fields: [userId], references: [id], onDelete: NoAction)
amount Int
type CreditTransactionType
@@ -587,6 +588,16 @@ model CreditRefundRequest {
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
model SearchTerms {
// User ID not being logged as this is anonymous analytics data
// Not using uuid as we want to minimise table size
id BigInt @id @default(autoincrement())
createdDate DateTime
searchTerm String
@@index([createdDate])
}
model Profile {
id String @id @default(uuid())
createdAt DateTime @default(now())
@@ -657,6 +668,7 @@ view StoreAgent {
rating Float
versions String[]
is_available Boolean @default(true)
useForOnboarding Boolean @default(false)
// Materialized views used (refreshed every 15 minutes via pg_cron):
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
@@ -734,6 +746,9 @@ model StoreListing {
// URL-friendly identifier for this agent (moved from StoreListingVersion)
slug String
// Allow this agent to be used during onboarding
useForOnboarding Boolean @default(false)
// The currently active version that should be shown to users
activeVersionId String? @unique
ActiveVersion StoreListingVersion? @relation("ActiveVersion", fields: [activeVersionId], references: [id])
@@ -769,13 +784,13 @@ model StoreListingVersion {
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version])
// Content fields
name String
subHeading String
videoUrl String?
imageUrls String[]
description String
name String
subHeading String
videoUrl String?
imageUrls String[]
description String
instructions String?
categories String[]
categories String[]
isFeatured Boolean @default(false)
@@ -873,6 +888,16 @@ model APIKey {
@@index([userId, status])
}
model UserBalance {
userId String @id
balance Int @default(0)
updatedAt DateTime @updatedAt
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@index([userId])
}
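The new `UserBalance` model caches a per-user balance alongside the existing `CreditTransaction` rows. A minimal pure-logic sketch of the invariant such a cache has to maintain (names and shapes here are hypothetical illustrations, not the repo's actual service code):

```typescript
// Hypothetical sketch: a cached balance should always equal the sum of the
// user's transaction amounts, so it can be recomputed or updated incrementally.

interface CreditTransactionLike {
  amount: number; // mirrors `amount Int` on CreditTransaction
}

// Full recompute from the transaction history.
function computeBalance(transactions: CreditTransactionLike[]): number {
  return transactions.reduce((sum, tx) => sum + tx.amount, 0);
}

// Incremental update when a new transaction is recorded.
function applyTransaction(cachedBalance: number, tx: CreditTransactionLike): number {
  return cachedBalance + tx.amount;
}
```

Folding `applyTransaction` over a transaction stream agrees with `computeBalance` over the full list, which is the property a cached `UserBalance.balance` column relies on.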
enum APIKeyStatus {
ACTIVE
REVOKED

View File

@@ -1,5 +1,5 @@
{
"email": "test@example.com",
"id": "test-user-id",
"id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
"name": "Test User"
}

View File

@@ -28,6 +28,6 @@
"recommended_schedule_cron": null,
"sub_graphs": [],
"trigger_setup_info": null,
"user_id": "test-user-id",
"user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
"version": 1
}

View File

@@ -26,7 +26,7 @@
"recommended_schedule_cron": null,
"sub_graphs": [],
"trigger_setup_info": null,
"user_id": "test-user-id",
"user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
"version": 1
}
]

View File

@@ -749,10 +749,11 @@ class TestDataCreator:
        """Add credits to users."""
        print("Adding credits to users...")
        credit_model = get_user_credit_model()
        for user in self.users:
            try:
                # Get user-specific credit model
                credit_model = await get_user_credit_model(user["id"])
                # Skip credits for disabled credit model to avoid errors
                if (
                    hasattr(credit_model, "__class__")

View File

@@ -11,7 +11,6 @@
NEXT_PUBLIC_LAUNCHDARKLY_ENABLED=false
NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID=687ab1372f497809b131e06e
NEXT_PUBLIC_SHOW_BILLING_PAGE=false
NEXT_PUBLIC_TURNSTILE=disabled
NEXT_PUBLIC_REACT_QUERY_DEVTOOL=true

View File

@@ -2,6 +2,7 @@ import { withSentryConfig } from "@sentry/nextjs";
/** @type {import('next').NextConfig} */
const nextConfig = {
productionBrowserSourceMaps: true,
images: {
domains: [
"images.unsplash.com",
@@ -74,6 +75,14 @@ export default isDevelopmentBuild
// since the source is public anyway :)
hideSourceMaps: false,
// This helps Sentry with sourcemaps... https://docs.sentry.io/platforms/javascript/guides/nextjs/sourcemaps/
      sourcemaps: {
        disable: false, // Source maps are enabled by default
        assets: ["**/*.js", "**/*.js.map"], // Specify which files to upload
        ignore: ["**/node_modules/**"], // Files to exclude
        deleteSourcemapsAfterUpload: true, // Security: delete after upload
      },
// Automatically tree-shake Sentry logger statements to reduce bundle size
disableLogger: true,

View File

@@ -67,6 +67,12 @@ export default defineConfig({
useQuery: true,
},
},
"getV2Builder search": {
query: {
useInfinite: true,
useInfiniteQueryParam: "page",
},
},
},
},
},
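The `useInfinite: true` / `useInfiniteQueryParam: "page"` options make orval generate an infinite-query hook whose next `page` value comes from a `getNextPageParam` callback. A minimal sketch of that pagination contract (the paged response shape is an assumption for illustration, not the actual generated types):

```typescript
// Hypothetical paged response; real generated types come from the API schema.
interface PagedResponse<T> {
  items: T[];
  page: number;       // 1-based page that was just fetched
  totalPages: number;
}

// What a React Query getNextPageParam callback computes: the value fed back
// into the "page" query param, or undefined when there is nothing left.
function getNextPageParam<T>(lastPage: PagedResponse<T>): number | undefined {
  return lastPage.page < lastPage.totalPages ? lastPage.page + 1 : undefined;
}
```

Returning `undefined` is how React Query's infinite queries signal that `hasNextPage` is false, so the UI stops requesting further pages.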

View File

@@ -0,0 +1,93 @@
import type { GraphMeta } from "@/lib/autogpt-server-api";
import type {
  BlockIOCredentialsSubSchema,
  CredentialsMetaInput,
} from "@/lib/autogpt-server-api/types";
import type { InputValues } from "./types";

export function computeInitialAgentInputs(
  agent: GraphMeta | null,
  existingInputs?: InputValues | null,
): InputValues {
  const properties = agent?.input_schema?.properties || {};
  const result: InputValues = {};

  Object.entries(properties).forEach(([key, subSchema]) => {
    if (
      existingInputs &&
      key in existingInputs &&
      existingInputs[key] != null
    ) {
      result[key] = existingInputs[key];
      return;
    }
    // GraphIOSubSchema.default is typed as string, but server may return other primitives
    const def = (subSchema as unknown as { default?: string | number }).default;
    result[key] = def ?? "";
  });

  return result;
}

export function getAgentCredentialsInputFields(agent: GraphMeta | null) {
  const hasNoInputs =
    !agent?.credentials_input_schema ||
    typeof agent.credentials_input_schema !== "object" ||
    !("properties" in agent.credentials_input_schema) ||
    !agent.credentials_input_schema.properties;

  if (hasNoInputs) return {};

  return agent.credentials_input_schema.properties;
}

export function areAllCredentialsSet(
  fields: Record<string, BlockIOCredentialsSubSchema>,
  inputs: Record<string, CredentialsMetaInput | undefined>,
) {
  const required = Object.keys(fields || {});
  return required.every((k) => Boolean(inputs[k]));
}

type IsRunDisabledParams = {
  agent: GraphMeta | null;
  isRunning: boolean;
  agentInputs: InputValues | null | undefined;
  credentialsRequired: boolean;
  credentialsSatisfied: boolean;
};

export function isRunDisabled({
  agent,
  isRunning,
  agentInputs,
  credentialsRequired,
  credentialsSatisfied,
}: IsRunDisabledParams) {
  const hasEmptyInput = Object.values(agentInputs || {}).some(
    (value) => String(value).trim() === "",
  );

  if (hasEmptyInput) return true;
  if (!agent) return true;
  if (isRunning) return true;
  if (credentialsRequired && !credentialsSatisfied) return true;

  return false;
}

export function getSchemaDefaultCredentials(
  schema: BlockIOCredentialsSubSchema,
): CredentialsMetaInput | undefined {
  return schema.default as CredentialsMetaInput | undefined;
}

export function sanitizeCredentials(
  map: Record<string, CredentialsMetaInput | undefined>,
): Record<string, CredentialsMetaInput> {
  const sanitized: Record<string, CredentialsMetaInput> = {};
  for (const [key, value] of Object.entries(map)) {
    if (value) sanitized[key] = value;
  }
  return sanitized;
}
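Taken together, these helpers gate the onboarding Run button on pure data, which is what makes them unit-testable. A standalone sketch of how the gating composes (a simplified restatement for illustration, not an import of the actual module):

```typescript
// Simplified restatement of the run-gating logic, inlined so it runs standalone.
type Inputs = Record<string, string | number>;

function hasEmptyInput(inputs: Inputs | null): boolean {
  // Any blank (or whitespace-only) value keeps the button disabled.
  return Object.values(inputs || {}).some((v) => String(v).trim() === "");
}

function isRunDisabled(opts: {
  hasAgent: boolean;
  isRunning: boolean;
  inputs: Inputs | null;
  credentialsRequired: boolean;
  credentialsSatisfied: boolean;
}): boolean {
  if (hasEmptyInput(opts.inputs)) return true;
  if (!opts.hasAgent) return true;
  if (opts.isRunning) return true;
  if (opts.credentialsRequired && !opts.credentialsSatisfied) return true;
  return false;
}
```

An agent with every input filled and all required credentials selected enables the button; a single blank input, a missing credential, or an in-flight run disables it again.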

View File

@@ -13,13 +13,24 @@ import {
} from "@/components/__legacy__/ui/card";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { GraphMeta, StoreAgentDetails } from "@/lib/autogpt-server-api";
import type { InputValues } from "./types";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { cn } from "@/lib/utils";
import { Play } from "lucide-react";
import { useRouter } from "next/navigation";
import { useCallback, useEffect, useState } from "react";
import { useEffect, useState } from "react";
import { RunAgentInputs } from "@/app/(platform)/library/agents/[id]/components/AgentRunsView/components/RunAgentInputs/RunAgentInputs";
import { InformationTooltip } from "@/components/molecules/InformationTooltip/InformationTooltip";
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/AgentRunsView/components/CredentialsInputs/CredentialsInputs";
import type { CredentialsMetaInput } from "@/lib/autogpt-server-api/types";
import {
areAllCredentialsSet,
computeInitialAgentInputs,
getAgentCredentialsInputFields,
isRunDisabled,
getSchemaDefaultCredentials,
sanitizeCredentials,
} from "./helpers";
export default function Page() {
const { state, updateState, setStep } = useOnboarding(
@@ -30,13 +41,16 @@ export default function Page() {
const [agent, setAgent] = useState<GraphMeta | null>(null);
const [storeAgent, setStoreAgent] = useState<StoreAgentDetails | null>(null);
const [runningAgent, setRunningAgent] = useState(false);
const [inputCredentials, setInputCredentials] = useState<
Record<string, CredentialsMetaInput | undefined>
>({});
const { toast } = useToast();
const router = useRouter();
const api = useBackendAPI();
useEffect(() => {
setStep(5);
}, [setStep]);
}, []);
useEffect(() => {
if (!state?.selectedStoreListingVersionId) {
@@ -49,40 +63,36 @@ export default function Page() {
});
api
.getGraphMetaByStoreListingVersionID(state.selectedStoreListingVersionId)
.then((agent) => {
setAgent(agent);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const update: { [key: string]: any } = {};
// Set default values from schema
Object.entries(agent.input_schema.properties).forEach(
([key, value]) => {
// Skip if already set
if (state.agentInput && state.agentInput[key]) {
update[key] = state.agentInput[key];
return;
}
update[key] = value.type !== "null" ? value.default || "" : "";
},
.then((meta) => {
setAgent(meta);
const update = computeInitialAgentInputs(
meta,
(state.agentInput as unknown as InputValues) || null,
);
updateState({
agentInput: update,
});
updateState({ agentInput: update });
});
}, [api, setAgent, updateState, state?.selectedStoreListingVersionId]);
const setAgentInput = useCallback(
(key: string, value: string) => {
updateState({
agentInput: {
...state?.agentInput,
[key]: value,
},
});
},
[state?.agentInput, updateState],
const agentCredentialsInputFields = getAgentCredentialsInputFields(agent);
const credentialsRequired =
Object.keys(agentCredentialsInputFields || {}).length > 0;
const allCredentialsAreSet = areAllCredentialsSet(
agentCredentialsInputFields,
inputCredentials,
);
const runAgent = useCallback(async () => {
function setAgentInput(key: string, value: string) {
updateState({
agentInput: {
...state?.agentInput,
[key]: value,
},
});
}
async function runAgent() {
if (!agent) {
return;
}
@@ -95,6 +105,7 @@ export default function Page() {
libraryAgent.graph_id,
libraryAgent.graph_version,
state?.agentInput || {},
sanitizeCredentials(inputCredentials),
);
updateState({
onboardingAgentExecutionId: runID,
@@ -111,7 +122,7 @@ export default function Page() {
});
setRunningAgent(false);
}
}, [api, agent, router, state?.agentInput, storeAgent, updateState, toast]);
}
const runYourAgent = (
<div className="ml-[104px] w-[481px] pl-5">
@@ -221,6 +232,30 @@ export default function Page() {
<span className="mt-4 text-base font-normal leading-normal text-zinc-600">
When you&apos;re done, click <b>Run Agent</b>.
</span>
{Object.entries(agentCredentialsInputFields || {}).map(
([key, inputSubSchema]) => (
<div key={key} className="mt-4">
<CredentialsInput
schema={inputSubSchema}
selectedCredentials={
inputCredentials[key] ??
getSchemaDefaultCredentials(inputSubSchema)
}
onSelectCredentials={(value) =>
setInputCredentials((prev) => ({
...prev,
[key]: value,
}))
}
siblingInputs={
(state?.agentInput || undefined) as
| Record<string, any>
| undefined
}
/>
</div>
),
)}
<Card className="agpt-box mt-4">
<CardHeader>
<CardTitle className="font-poppins text-lg">Input</CardTitle>
@@ -250,13 +285,14 @@ export default function Page() {
variant="violet"
className="mt-8 w-[136px]"
loading={runningAgent}
disabled={
Object.values(state?.agentInput || {}).some(
(value) => String(value).trim() === "",
) ||
!agent ||
runningAgent
}
disabled={isRunDisabled({
agent,
isRunning: runningAgent,
agentInputs:
(state?.agentInput as unknown as InputValues) || null,
credentialsRequired,
credentialsSatisfied: allCredentialsAreSet,
})}
onClick={runAgent}
icon={<Play className="mr-2" size={18} />}
>

Some files were not shown because too many files have changed in this diff.