Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
Synced 2026-01-12 16:48:06 -05:00

Compare commits: 43 commits
fix/sql-in ... swiftyos/s
| Author | SHA1 | Date |
|---|---|---|
| | 01bab66f5c | |
| | 30b2d6b50d | |
| | 97e77339fd | |
| | 6a360a49b1 | |
| | 045009a84a | |
| | 70cb7824fd | |
| | df1d15fcfe | |
| | eb022e50a7 | |
| | 431042a391 | |
| | 7c6a9146f0 | |
| | 0264cb56d3 | |
| | ef552c189f | |
| | e75cf2b765 | |
| | 0f6d1f54ee | |
| | b3fe2b84ce | |
| | e13861ad33 | |
| | e2c24bd463 | |
| | 9f5afff83e | |
| | ced61e2640 | |
| | c9a7cc63da | |
| | 7afa01a168 | |
| | 2f9aba0420 | |
| | ff4b0929e1 | |
| | 2230c76863 | |
| | b3443e0549 | |
| | b1364b1701 | |
| | 95d66a035c | |
| | 30cdf9f0d9 | |
| | e47f0e7f2f | |
| | 13d71464a0 | |
| | be1947f6d1 | |
| | 1afebcf96b | |
| | d124c93ff8 | |
| | bc5eb8a8a5 | |
| | 872ef5fdfb | |
| | 12382e7990 | |
| | d68a3a1b53 | |
| | 863e213af3 | |
| | c61af53a74 | |
| | eb94503de8 | |
| | ee4feff8c2 | |
| | 9147c2d6c8 | |
| | a3af430c69 | |
94
.github/copilot-instructions.md
vendored
@@ -12,7 +12,6 @@ This file provides comprehensive onboarding information for GitHub Copilot codin
- **Infrastructure** - Docker configurations, CI/CD, and development tools

**Primary Languages & Frameworks:**

- **Backend**: Python 3.10-3.13, FastAPI, Prisma ORM, PostgreSQL, RabbitMQ
- **Frontend**: TypeScript, Next.js 15, React, Tailwind CSS, Radix UI
- **Development**: Docker, Poetry, pnpm, Playwright, Storybook
@@ -24,17 +23,15 @@ This file provides comprehensive onboarding information for GitHub Copilot codin
**Always run these commands in the correct directory and in this order:**

1. **Initial Setup** (required once):

```bash
# Clone and enter repository
git clone <repo> && cd AutoGPT

# Start all services (database, redis, rabbitmq, clamav)
cd autogpt_platform && docker compose --profile local up deps --build --detach
```

2. **Backend Setup** (always run before backend development):

```bash
cd autogpt_platform/backend
poetry install  # Install dependencies
@@ -51,7 +48,6 @@ This file provides comprehensive onboarding information for GitHub Copilot codin
### Runtime Requirements

**Critical:** Always ensure Docker services are running before starting development:

```bash
cd autogpt_platform && docker compose --profile local up deps --build --detach
```
@@ -62,7 +58,6 @@ cd autogpt_platform && docker compose --profile local up deps --build --detach
### Development Commands

**Backend Development:**

```bash
cd autogpt_platform/backend
poetry run serve  # Start development server (port 8000)
@@ -73,7 +68,6 @@ poetry run lint # Lint code (ruff) - run after format
```

**Frontend Development:**

```bash
cd autogpt_platform/frontend
pnpm dev  # Start development server (port 3000) - use for active development
@@ -87,27 +81,23 @@ pnpm storybook # Start component development server
### Testing Strategy

**Backend Tests:**

- **Block Tests**: `poetry run pytest backend/blocks/test/test_block.py -xvs` (validates all blocks)
- **Specific Block**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[BlockName]' -xvs`
- **Snapshot Tests**: Use `--snapshot-update` when output changes, always review with `git diff`

**Frontend Tests:**

- **E2E Tests**: Always run `pnpm dev` before `pnpm test` (Playwright requires a running instance)
- **Component Tests**: Use Storybook for isolated component development

### Critical Validation Steps

**Before committing changes:**

1. Run `poetry run format` (backend) and `pnpm format` (frontend)
2. Ensure all tests pass in modified areas
3. Verify Docker services are still running
4. Check that database migrations apply cleanly

**Common Issues & Workarounds:**

- **Prisma issues**: Run `poetry run prisma generate` after schema changes
- **Permission errors**: Ensure Docker has proper permissions
- **Port conflicts**: Check the `docker-compose.yml` file for the current list of exposed ports. You can list all mapped ports with:
@@ -118,7 +108,6 @@ pnpm storybook # Start component development server
### Core Architecture

**AutoGPT Platform** (`autogpt_platform/`):

- `backend/` - FastAPI server with async support
- `backend/backend/` - Core API logic
- `backend/blocks/` - Agent execution blocks
@@ -132,7 +121,6 @@ pnpm storybook # Start component development server
- `docker-compose.yml` - Development stack orchestration

**Key Configuration Files:**

- `pyproject.toml` - Python dependencies and tooling
- `package.json` - Node.js dependencies and scripts
- `schema.prisma` - Database schema and migrations
@@ -148,7 +136,6 @@ pnpm storybook # Start component development server
### Development Workflow

**GitHub Actions**: Multiple CI/CD workflows in `.github/workflows/`

- `platform-backend-ci.yml` - Backend testing and validation
- `platform-frontend-ci.yml` - Frontend testing and validation
- `platform-fullstack-ci.yml` - End-to-end integration tests
@@ -159,13 +146,11 @@ pnpm storybook # Start component development server
### Key Source Files

**Backend Entry Points:**

- `backend/backend/server/server.py` - FastAPI application setup
- `backend/backend/data/` - Database models and user management
- `backend/blocks/` - Agent execution blocks and logic

**Frontend Entry Points:**

- `frontend/src/app/layout.tsx` - Root application layout
- `frontend/src/app/page.tsx` - Home page
- `frontend/src/lib/supabase/` - Authentication and database client
@@ -175,7 +160,6 @@ pnpm storybook # Start component development server
### Agent Block System

Agents are built using a visual block-based system where each block performs a single action. Blocks are defined in `backend/blocks/` and must include:

- Block definition with input/output schemas
- Execution logic with proper error handling
- Tests validating functionality
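The three requirements above can be sketched as a toy example. Note this is an illustration only: the real `Block` base class lives in `backend.data.block` and uses a richer schema system, so the simplified base class, the dict-based schemas, and the `WordCountBlock` name here are all stand-ins.

```python
from typing import Any, Iterator


class Block:
    """Simplified stand-in for the real Block base class."""

    input_schema: dict = {}
    output_schema: dict = {}

    def run(self, **inputs: Any) -> Iterator[tuple[str, Any]]:
        raise NotImplementedError


class WordCountBlock(Block):
    """Each block performs a single action - here, counting words."""

    # Block definition with input/output schemas
    input_schema = {"text": str}
    output_schema = {"count": int}

    def run(self, *, text: str) -> Iterator[tuple[str, Any]]:
        # Execution logic with proper error handling
        if not isinstance(text, str):
            raise ValueError("text must be a string")
        yield "count", len(text.split())
```

A test validating functionality would then assert on the yielded `(name, value)` pairs, which is the third requirement in the list.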
@@ -183,7 +167,6 @@ Agents are built using a visual block-based system where each block performs a s
### Database & ORM

**Prisma ORM** with PostgreSQL backend including pgvector for embeddings:

- Schema in `schema.prisma`
- Migrations in `backend/migrations/`
- Always run `prisma migrate dev` and `prisma generate` after schema changes
@@ -191,15 +174,13 @@ Agents are built using a visual block-based system where each block performs a s
## Environment Configuration

### Configuration Files Priority Order

1. **Backend**: `/backend/.env.default` → `/backend/.env` (user overrides)
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
3. **Platform**: `/.env.default` (Supabase/shared) → `/.env` (user overrides)
4. Docker Compose `environment:` sections override file-based config
5. Shell environment variables have highest precedence
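The layering above (defaults first, user overrides next, shell environment last) can be sketched as follows. This is a hedged illustration of the precedence order, not the platform's actual loader; the `load_env_file` and `resolve_config` names and the simplified `KEY=VALUE` parsing are assumptions.

```python
import os


def load_env_file(path: str) -> dict:
    """Parse a .env-style file into a dict; missing files yield nothing."""
    env = {}
    try:
        with open(path) as f:
            for line in f:
                line = line.strip()
                if line and not line.startswith("#") and "=" in line:
                    key, _, value = line.partition("=")
                    env[key.strip()] = value.strip()
    except FileNotFoundError:
        pass
    return env


def resolve_config(shell_env=None) -> dict:
    """Merge sources lowest-to-highest precedence; later updates win."""
    config = {}
    for path in (".env.default", ".env"):  # defaults, then user overrides
        config.update(load_env_file(path))
    # Shell environment variables have highest precedence
    config.update(shell_env if shell_env is not None else os.environ)
    return config
```

The key design point is that each later `update` silently overwrites earlier keys, which is exactly the "user overrides" behavior the priority list describes.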

### Docker Environment Setup

- All services use hardcoded defaults (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
@@ -208,7 +189,6 @@ Agents are built using a visual block-based system where each block performs a s
## Advanced Development Patterns

### Adding New Blocks

1. Create file in `/backend/backend/blocks/`
2. Inherit from `Block` base class with input/output schemas
3. Implement `run` method with proper error handling
@@ -218,7 +198,6 @@ Agents are built using a visual block-based system where each block performs a s
7. Consider how inputs/outputs connect with other blocks in graph editor

### API Development

1. Update routes in `/backend/backend/server/routers/`
2. Add/update Pydantic models in same directory
3. Write tests alongside route files
@@ -226,76 +205,21 @@ Agents are built using a visual block-based system where each block performs a s
5. Run `poetry run test` to verify changes

### Frontend Development

**📖 Complete Frontend Guide**: See `autogpt_platform/frontend/CONTRIBUTING.md` and `autogpt_platform/frontend/.cursorrules` for comprehensive patterns and conventions.

**Quick Reference:**

**Component Structure:**

- Separate render logic from data/behavior
- Structure: `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
- Exception: Small components (3-4 lines of logic) can be inline
- Render-only components can be direct files without folders

**Data Fetching:**

- Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Generated via Orval from backend OpenAPI spec
- Pattern: `use{Method}{Version}{OperationName}`
- Example: `useGetV2ListLibraryAgents`
- Regenerate with: `pnpm generate:api`
- **Never** use deprecated `BackendAPI` or `src/lib/autogpt-server-api/*`

**Code Conventions:**

- Use function declarations for components and handlers (not arrow functions)
- Only arrow functions for small inline lambdas (map, filter, etc.)
- Components: `PascalCase`, Hooks: `camelCase` with `use` prefix
- No barrel files or `index.ts` re-exports
- Minimal comments (code should be self-documenting)

**Styling:**

- Use Tailwind CSS utilities only
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Never use `src/components/__legacy__/*`
- Only use Phosphor Icons (`@phosphor-icons/react`)
- Prefer design tokens over hardcoded values

**Error Handling:**

- Render errors: Use `<ErrorCard />` component
- Mutation errors: Display with toast notifications
- Manual exceptions: Use `Sentry.captureException()`
- Global error boundaries already configured

**Testing:**

- Add/update Storybook stories for UI components (`pnpm storybook`)
- Run Playwright E2E tests with `pnpm test`
- Verify in Chromatic after PR

**Architecture:**

- Default to client components ("use client")
- Server components only for SEO or extreme TTFB needs
- Use React Query for server state (via generated hooks)
- Co-locate UI state in components/hooks
1. Components in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for component development
4. Test user-facing features with Playwright E2E tests
5. Update protected routes in middleware when needed
### Security Guidelines

**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):

- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses allow list approach for cacheable paths (static assets, health checks, public pages)
- Prevents sensitive data caching in browsers/proxies
- Add new cacheable endpoints to `CACHEABLE_PATHS`
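The default-deny allow-list approach described above can be sketched in a few lines. This is an illustration of the idea, not the actual middleware in `security.py`; the `apply_cache_headers` helper and the sample paths in `CACHEABLE_PATHS` are assumptions.

```python
# Illustrative allow list: only explicitly listed path prefixes may be cached.
CACHEABLE_PATHS = ("/static/", "/health", "/favicon.ico")

NO_STORE = "no-store, no-cache, must-revalidate, private"


def apply_cache_headers(path: str, headers: dict) -> dict:
    """Default-deny caching: everything gets no-store unless allow-listed."""
    if any(path.startswith(prefix) for prefix in CACHEABLE_PATHS):
        # Cacheable paths may keep a caching policy the handler already set.
        headers.setdefault("Cache-Control", "public, max-age=3600")
    else:
        # Force no-store on everything else to keep sensitive responses
        # out of browser and proxy caches.
        headers["Cache-Control"] = NO_STORE
    return headers
```

The design choice worth noting is the direction of the default: forgetting to register a new endpoint makes it uncacheable (safe) rather than cacheable (a data-leak risk).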
### CI/CD Alignment

The repository has comprehensive CI workflows that test:

- **Backend**: Python 3.11-3.13, services (Redis/RabbitMQ/ClamAV), Prisma migrations, Poetry lock validation
- **Frontend**: Node.js 21, pnpm, Playwright with Docker Compose stack, API schema validation
- **Integration**: Full-stack type checking and E2E testing
@@ -305,7 +229,6 @@ Match these patterns when developing locally - the copilot setup environment mir
## Collaboration with Other AI Assistants

This repository is actively developed with assistance from Claude (via CLAUDE.md files). When working on this codebase:

- Check for existing CLAUDE.md files that provide additional context
- Follow established patterns and conventions already in the codebase
- Maintain consistency with existing code style and architecture
@@ -314,9 +237,8 @@ This repository is actively developed with assistance from Claude (via CLAUDE.md
## Trust These Instructions

These instructions are comprehensive and tested. Only perform additional searches if:

1. Information here is incomplete for your specific task
2. You encounter errors not covered by the workarounds
3. You need to understand implementation details not covered above

For detailed platform development patterns, refer to `autogpt_platform/CLAUDE.md` and `AGENTS.md` in the repository root.
3
.github/workflows/platform-frontend-ci.yml
vendored
@@ -217,6 +217,9 @@ jobs:
      - name: Install dependencies
        run: pnpm install --frozen-lockfile

      - name: Generate API client
        run: pnpm generate:api

      - name: Install Browser 'chromium'
        run: pnpm playwright install --with-deps chromium
@@ -63,9 +63,6 @@ poetry run pytest path/to/test.py --snapshot-update
# Install dependencies
cd frontend && pnpm i

# Generate API client from OpenAPI spec
pnpm generate:api

# Start development server
pnpm dev

@@ -78,23 +75,12 @@ pnpm storybook
# Build production
pnpm build

# Format and lint
pnpm format

# Type checking
pnpm types
```

**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns.
We have a components library in `autogpt_platform/frontend/src/components/atoms` that should be used when adding new pages and components.

**Key Frontend Conventions:**

- Separate render logic from data/behavior in components
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Use function declarations (not arrow functions) for components/handlers
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Only use Phosphor Icons
- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`

## Architecture Overview
@@ -109,16 +95,11 @@ pnpm types

### Frontend Architecture

- **Framework**: Next.js 15 App Router (client-first approach)
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
- **State Management**: React Query for server state, co-located UI state in components/hooks
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
- **Framework**: Next.js App Router with React Server Components
- **State Management**: React hooks + Supabase client for real-time updates
- **Workflow Builder**: Visual graph editor using @xyflow/react
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
- **Icons**: Phosphor Icons only
- **UI Components**: Radix UI primitives with Tailwind CSS styling
- **Feature Flags**: LaunchDarkly integration
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
- **Testing**: Playwright for E2E, Storybook for component development

### Key Concepts

@@ -172,7 +153,6 @@ Key models (defined in `/backend/schema.prisma`):
**Adding a new block:**

Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:

- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
@@ -180,7 +160,6 @@ Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-
- File organization

Quick steps:

1. Create new file in `/backend/backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
@@ -201,20 +180,10 @@ ex: do the inputs and outputs tie well together?

**Frontend feature development:**

See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:

1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
   - Add `usePageName.ts` hook for logic
   - Put sub-components in local `components/` folder
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
   - Use design system components from `src/components/` (atoms, molecules, organisms)
   - Never use `src/components/__legacy__/*`
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
   - Regenerate with `pnpm generate:api`
   - Pattern: `use{Method}{Version}{OperationName}`
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
1. Components go in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for new components
4. Test with Playwright if user-facing

### Security Implementation
@@ -1,57 +0,0 @@
.PHONY: start-core stop-core logs-core format lint migrate run-backend run-frontend

# Run just Supabase + Redis + RabbitMQ
start-core:
	docker compose up -d deps

# Stop core services
stop-core:
	docker compose stop deps

reset-db:
	rm -rf db/docker/volumes/db/data
	cd backend && poetry run prisma migrate deploy
	cd backend && poetry run prisma generate

# View logs for core services
logs-core:
	docker compose logs -f deps

# Run formatting and linting for backend and frontend
format:
	cd backend && poetry run format
	cd frontend && pnpm format
	cd frontend && pnpm lint

init-env:
	cp -n .env.default .env || true
	cd backend && cp -n .env.default .env || true
	cd frontend && cp -n .env.default .env || true

# Run migrations for backend
migrate:
	cd backend && poetry run prisma migrate deploy
	cd backend && poetry run prisma generate

run-backend:
	cd backend && poetry run app

run-frontend:
	cd frontend && pnpm dev

test-data:
	cd backend && poetry run python test/test_data_creator.py

help:
	@echo "Usage: make <target>"
	@echo "Targets:"
	@echo "  start-core   - Start just the core services (Supabase, Redis, RabbitMQ) in background"
	@echo "  stop-core    - Stop the core services"
	@echo "  reset-db     - Reset the database by deleting the volume"
	@echo "  logs-core    - Tail the logs for core services"
	@echo "  format       - Format & lint backend (Python) and frontend (TypeScript) code"
	@echo "  migrate      - Run backend database migrations"
	@echo "  run-backend  - Run the backend FastAPI server"
	@echo "  run-frontend - Run the frontend Next.js development server"
	@echo "  test-data    - Run the test data creator"
@@ -38,37 +38,6 @@ To run the AutoGPT Platform, follow these steps:

4. After all the services are in ready state, open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.

### Running Just Core Services

You can now run the following to enable just the core services.

```
# For help
make help

# Run just Supabase + Redis + RabbitMQ
make start-core

# Stop core services
make stop-core

# View logs from core services
make logs-core

# Run formatting and linting for backend and frontend
make format

# Run migrations for backend database
make migrate

# Run backend server
make run-backend

# Run frontend development server
make run-frontend
```

### Docker Compose Commands

Here are some useful Docker Compose commands for managing your AutoGPT Platform:
@@ -94,36 +94,42 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
    config = LoggingConfig()
    log_handlers: list[logging.Handler] = []

    structured_logging = config.enable_cloud_logging or force_cloud_logging

    # Console output handlers
    if not structured_logging:
        stdout = logging.StreamHandler(stream=sys.stdout)
        stdout.setLevel(config.level)
        stdout.addFilter(BelowLevelFilter(logging.WARNING))
        if config.level == logging.DEBUG:
            stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
        else:
            stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
    stdout = logging.StreamHandler(stream=sys.stdout)
    stdout.setLevel(config.level)
    stdout.addFilter(BelowLevelFilter(logging.WARNING))
    if config.level == logging.DEBUG:
        stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
    else:
        stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))

        stderr = logging.StreamHandler()
        stderr.setLevel(logging.WARNING)
        if config.level == logging.DEBUG:
            stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
        else:
            stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
    stderr = logging.StreamHandler()
    stderr.setLevel(logging.WARNING)
    if config.level == logging.DEBUG:
        stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
    else:
        stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))

        log_handlers += [stdout, stderr]
    log_handlers += [stdout, stderr]
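The console setup above splits output between two streams: stdout carries records below WARNING (enforced by `BelowLevelFilter`), while stderr takes WARNING and above. A minimal standalone sketch of that filter, assuming the real one in the logging module behaves the same way:

```python
import logging


class BelowLevelFilter(logging.Filter):
    """Pass only records strictly below the given level.

    Attached to the stdout handler so that WARNING+ records go
    exclusively to the stderr handler and are never duplicated.
    """

    def __init__(self, below_level: int):
        super().__init__()
        self.below_level = below_level

    def filter(self, record: logging.LogRecord) -> bool:
        return record.levelno < self.below_level
```

Without such a filter, a record at ERROR level would be emitted by both handlers, since handler levels are only lower bounds.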
    # Cloud logging setup
    else:
        # Use Google Cloud Structured Log Handler. Log entries are printed to stdout
        # in a JSON format which is automatically picked up by Google Cloud Logging.
        from google.cloud.logging.handlers import StructuredLogHandler
    if config.enable_cloud_logging or force_cloud_logging:
        import google.cloud.logging
        from google.cloud.logging.handlers import CloudLoggingHandler
        from google.cloud.logging_v2.handlers.transports import (
            BackgroundThreadTransport,
        )

        structured_log_handler = StructuredLogHandler(stream=sys.stdout)
        structured_log_handler.setLevel(config.level)
        log_handlers.append(structured_log_handler)
        client = google.cloud.logging.Client()
        # Use BackgroundThreadTransport to prevent blocking the main thread
        # and deadlocks when gRPC calls to Google Cloud Logging hang
        cloud_handler = CloudLoggingHandler(
            client,
            name="autogpt_logs",
            transport=BackgroundThreadTransport,
        )
        cloud_handler.setLevel(config.level)
        log_handlers.append(cloud_handler)

    # File logging setup
    if config.enable_file_logging:
@@ -179,13 +185,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:

    # Configure the root logger
    logging.basicConfig(
        format=(
            "%(levelname)s %(message)s"
            if structured_logging
            else (
                DEBUG_LOG_FORMAT if config.level == logging.DEBUG else SIMPLE_LOG_FORMAT
            )
        ),
        format=DEBUG_LOG_FORMAT if config.level == logging.DEBUG else SIMPLE_LOG_FORMAT,
        level=config.level,
        handlers=log_handlers,
    )
18
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -1719,6 +1719,22 @@ files = [
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
strenum = ">=0.4.15,<0.5.0"

[[package]]
name = "tenacity"
version = "9.1.2"
description = "Retry code until it succeeds"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
    {file = "tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138"},
    {file = "tenacity-9.1.2.tar.gz", hash = "sha256:1169d376c297e7de388d18b4481760d478b0e99a777cad3a9c86e556f4b697cb"},
]

[package.extras]
doc = ["reno", "sphinx"]
test = ["pytest", "tornado (>=4.5)", "typeguard"]

[[package]]
name = "tomli"
version = "2.2.1"
@@ -1929,4 +1945,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<4.0"
content-hash = "0c40b63c3c921846cf05ccfb4e685d4959854b29c2c302245f9832e20aac6954"
content-hash = "5ec9e6cd2ef7524a356586354755215699e7b37b9bbdfbabc9c73b43085915f4"

@@ -19,6 +19,7 @@ pydantic-settings = "^2.10.1"
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
redis = "^6.2.0"
supabase = "^2.16.0"
tenacity = "^9.1.2"
uvicorn = "^0.35.0"

[tool.poetry.group.dev.dependencies]
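The diff above adds `tenacity`, whose description is "Retry code until it succeeds." To make the concept concrete without depending on the library itself, here is a stdlib-only sketch of the core retry-decorator idea; the `attempts`/`delay` parameter names are this sketch's own, not tenacity's API.

```python
import time


def retry(attempts: int = 3, delay: float = 0.0):
    """Retry a flaky callable up to `attempts` times, re-raising the
    last exception if every attempt fails."""

    def decorator(func):
        def wrapper(*args, **kwargs):
            last_exc = None
            for _ in range(attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as exc:
                    last_exc = exc
                    time.sleep(delay)  # back off before the next attempt
            raise last_exc

        return wrapper

    return decorator
```

tenacity layers configurable wait strategies, stop conditions, and async support on top of this same pattern, which is why the platform depends on it rather than hand-rolling retries per call site.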
@@ -47,7 +47,6 @@ RUN poetry install --no-ansi --no-root

# Generate Prisma client
COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
RUN poetry run prisma generate

FROM debian:13-slim AS server_dependencies

@@ -93,7 +92,6 @@ FROM server_dependencies AS migrate

# Migration stage only needs schema and migrations - much lighter than full backend
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
COPY autogpt_platform/backend/backend/data/partial_types.py /app/autogpt_platform/backend/backend/data/partial_types.py
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations

FROM server_dependencies AS server
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
T = TypeVar("T")


@cached(ttl_seconds=3600)
@cached(ttl_seconds=3600)  # Cache blocks for 1 hour
def load_all_blocks() -> dict[str, type["Block"]]:
    from backend.data.block import Block
    from backend.util.settings import Config
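The `@cached(ttl_seconds=3600)` decorator above memoizes `load_all_blocks` for an hour. A minimal sketch of what a TTL-based cache decorator like this does, assuming the real one in the backend's caching utilities is richer (async support, keyword arguments, invalidation):

```python
import time
from functools import wraps


def cached(ttl_seconds: float):
    """Memoize a function's result per argument tuple for ttl_seconds."""

    def decorator(func):
        store = {}  # args -> (expiry_timestamp, value)

        @wraps(func)
        def wrapper(*args):
            now = time.monotonic()
            entry = store.get(args)
            if entry is not None and entry[0] > now:
                return entry[1]  # still fresh: skip the expensive call
            value = func(*args)
            store[args] = (now + ttl_seconds, value)
            return value

        return wrapper

    return decorator
```

For `load_all_blocks` the effect is that block discovery (scanning modules, building the registry) runs at most once per hour per process instead of on every request.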
@@ -66,7 +66,6 @@ class AddToDictionaryBlock(Block):
    dictionary: dict[Any, Any] = SchemaField(
        default_factory=dict,
        description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
        advanced=False,
    )
    key: str = SchemaField(
        default="",
@@ -4,13 +4,13 @@ import mimetypes
from pathlib import Path
from typing import Any

import aiohttp
import discord
from pydantic import SecretStr

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, SchemaField
from backend.util.file import store_media_file
from backend.util.request import Requests
from backend.util.type import MediaFileType

from ._auth import (
@@ -114,9 +114,10 @@ class ReadDiscordMessagesBlock(Block):
            if message.attachments:
                attachment = message.attachments[0]  # Process the first attachment
                if attachment.filename.endswith((".txt", ".py")):
                    response = await Requests().get(attachment.url)
                    file_content = response.text()
                    self.output_data += f"\n\nFile from user: {attachment.filename}\nContent: {file_content}"
                    async with aiohttp.ClientSession() as session:
                        async with session.get(attachment.url) as response:
                            file_content = response.text()
                            self.output_data += f"\n\nFile from user: {attachment.filename}\nContent: {file_content}"

            await client.close()

@@ -170,11 +171,11 @@ class SendDiscordMessageBlock(Block):
        description="The content of the message to send"
    )
    channel_name: str = SchemaField(
        description="Channel ID or channel name to send the message to"
        description="The name of the channel the message will be sent to"
    )
    server_name: str = SchemaField(
        description="Server name (only needed if using channel name)",
        advanced=True,
        description="The name of the server where the channel is located",
        advanced=True,  # Optional field for server name
        default="",
    )

@@ -230,49 +231,25 @@ class SendDiscordMessageBlock(Block):
        @client.event
        async def on_ready():
            print(f"Logged in as {client.user}")
            channel = None
            for guild in client.guilds:
                if server_name and guild.name != server_name:
                    continue
                for channel in guild.text_channels:
                    if channel.name == channel_name:
                        # Split message into chunks if it exceeds 2000 characters
                        chunks = self.chunk_message(message_content)
                        last_message = None
                        for chunk in chunks:
                            last_message = await channel.send(chunk)
                        result["status"] = "Message sent"
                        result["message_id"] = (
                            str(last_message.id) if last_message else ""
                        )
                        result["channel_id"] = str(channel.id)
                        await client.close()
                        return

            # Try to parse as channel ID first
            try:
                channel_id = int(channel_name)
                channel = client.get_channel(channel_id)
            except ValueError:
                # Not a valid ID, will try name lookup
                pass

            # If not found by ID (or not an ID), try name lookup
            if not channel:
                for guild in client.guilds:
                    if server_name and guild.name != server_name:
                        continue
                    for ch in guild.text_channels:
                        if ch.name == channel_name:
                            channel = ch
                            break
                    if channel:
                        break

            if not channel:
                result["status"] = f"Channel not found: {channel_name}"
                await client.close()
                return

            # Type check - ensure it's a text channel that can send messages
            if not hasattr(channel, "send"):
                result["status"] = (
                    f"Channel {channel_name} cannot receive messages (not a text channel)"
                )
                await client.close()
                return

            # Split message into chunks if it exceeds 2000 characters
            chunks = self.chunk_message(message_content)
            last_message = None
            for chunk in chunks:
                last_message = await channel.send(chunk)  # type: ignore
            result["status"] = "Message sent"
            result["message_id"] = str(last_message.id) if last_message else ""
            result["channel_id"] = str(channel.id)
            result["status"] = "Channel not found"
            await client.close()

        await client.start(token)
@@ -698,15 +675,16 @@ class SendDiscordFileBlock(Block):
|
||||
|
||||
elif file.startswith(("http://", "https://")):
|
||||
# URL - download the file
|
||||
response = await Requests().get(file)
|
||||
file_bytes = response.content
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(file) as response:
|
||||
file_bytes = await response.read()
|
||||
|
||||
# Try to get filename from URL if not provided
|
||||
if not filename:
|
||||
from urllib.parse import urlparse
|
||||
# Try to get filename from URL if not provided
|
||||
if not filename:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
path = urlparse(file).path
|
||||
detected_filename = Path(path).name or "download"
|
||||
path = urlparse(file).path
|
||||
detected_filename = Path(path).name or "download"
|
||||
else:
|
||||
# Local file path - read from stored media file
|
||||
# This would be a path from a previous block's output
|
||||
|
||||
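The Discord hunks above split outgoing messages into 2000-character chunks before sending, since Discord rejects longer messages. A minimal sketch of such a chunking helper (hypothetical — not the repository's actual `chunk_message` implementation):

```python
def chunk_message(content: str, limit: int = 2000) -> list[str]:
    """Split a message into pieces no longer than Discord's per-message limit."""
    if not content:
        return [""]
    # Slice the string into fixed-size windows; the last chunk may be shorter.
    return [content[i : i + limit] for i in range(0, len(content), limit)]
```

Each chunk is then sent as a separate message, with the last send's ID reported back.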
@@ -2,7 +2,7 @@ from typing import Any

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.json import loads
from backend.util.json import json


class StepThroughItemsBlock(Block):
@@ -68,7 +68,7 @@ class StepThroughItemsBlock(Block):
raise ValueError(
f"Input too large: {len(data)} bytes > {MAX_ITEM_SIZE} bytes"
)
items = loads(data)
items = json.loads(data)
else:
items = data

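The guard above rejects oversized string inputs before JSON parsing. As a standalone sketch (the `MAX_ITEM_SIZE` value here is assumed for illustration, not taken from the repository):

```python
import json

MAX_ITEM_SIZE = 1 * 1024 * 1024  # assumed 1MB cap, for illustration only

def safe_loads(data: str):
    # Refuse to parse strings larger than the cap, mirroring the hunk above.
    if len(data) > MAX_ITEM_SIZE:
        raise ValueError(f"Input too large: {len(data)} bytes > {MAX_ITEM_SIZE} bytes")
    return json.loads(data)
```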
@@ -1,8 +1,5 @@
from typing import List
from urllib.parse import quote

from typing_extensions import TypedDict

from backend.blocks.jina._auth import (
JinaCredentials,
JinaCredentialsField,
@@ -13,12 +10,6 @@ from backend.data.model import SchemaField
from backend.util.request import Requests


class Reference(TypedDict):
url: str
keyQuote: str
isSupportive: bool


class FactCheckerBlock(Block):
class Input(BlockSchema):
statement: str = SchemaField(
@@ -32,10 +23,6 @@ class FactCheckerBlock(Block):
)
result: bool = SchemaField(description="The result of the factuality check")
reason: str = SchemaField(description="The reason for the factuality result")
references: List[Reference] = SchemaField(
description="List of references supporting or contradicting the statement",
default=[],
)
error: str = SchemaField(description="Error message if the check fails")

def __init__(self):
@@ -66,11 +53,5 @@ class FactCheckerBlock(Block):
yield "factuality", data["factuality"]
yield "result", data["result"]
yield "reason", data["reason"]

# Yield references if present in the response
if "references" in data:
yield "references", data["references"]
else:
yield "references", []
else:
raise RuntimeError(f"Expected 'data' key not found in response: {data}")

@@ -62,10 +62,10 @@ TEST_CREDENTIALS_OAUTH = OAuth2Credentials(
title="Mock Linear API key",
username="mock-linear-username",
access_token=SecretStr("mock-linear-access-token"),
access_token_expires_at=1672531200, # Mock expiration time for short-lived token
access_token_expires_at=None,
refresh_token=SecretStr("mock-linear-refresh-token"),
refresh_token_expires_at=None,
scopes=["read", "write"],
scopes=["mock-linear-scopes"],
)

TEST_CREDENTIALS_API_KEY = APIKeyCredentials(

@@ -2,9 +2,7 @@
Linear OAuth handler implementation.
"""

import base64
import json
import time
from typing import Optional
from urllib.parse import urlencode

@@ -40,9 +38,8 @@ class LinearOAuthHandler(BaseOAuthHandler):
self.client_secret = client_secret
self.redirect_uri = redirect_uri
self.auth_base_url = "https://linear.app/oauth/authorize"
self.token_url = "https://api.linear.app/oauth/token"
self.token_url = "https://api.linear.app/oauth/token" # Correct token URL
self.revoke_url = "https://api.linear.app/oauth/revoke"
self.migrate_url = "https://api.linear.app/oauth/migrate_old_token"

def get_login_url(
self, scopes: list[str], state: str, code_challenge: Optional[str]
@@ -85,84 +82,19 @@ class LinearOAuthHandler(BaseOAuthHandler):

return True # Linear doesn't return JSON on successful revoke

async def migrate_old_token(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
"""
Migrate an old long-lived token to a new short-lived token with refresh token.

This uses Linear's /oauth/migrate_old_token endpoint to exchange current
long-lived tokens for short-lived tokens with refresh tokens without
requiring users to re-authorize.
"""
if not credentials.access_token:
raise ValueError("No access token to migrate")

request_body = {
"client_id": self.client_id,
"client_secret": self.client_secret,
}

headers = {
"Authorization": f"Bearer {credentials.access_token.get_secret_value()}",
"Content-Type": "application/x-www-form-urlencoded",
}

response = await Requests().post(
self.migrate_url, data=request_body, headers=headers
)

if not response.ok:
try:
error_data = response.json()
error_message = error_data.get("error", "Unknown error")
error_description = error_data.get("error_description", "")
if error_description:
error_message = f"{error_message}: {error_description}"
except json.JSONDecodeError:
error_message = response.text
raise LinearAPIException(
f"Failed to migrate Linear token ({response.status}): {error_message}",
response.status,
)

token_data = response.json()

# Extract token expiration
now = int(time.time())
expires_in = token_data.get("expires_in")
access_token_expires_at = None
if expires_in:
access_token_expires_at = now + expires_in

new_credentials = OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=credentials.title,
username=credentials.username,
access_token=token_data["access_token"],
scopes=credentials.scopes, # Preserve original scopes
refresh_token=token_data.get("refresh_token"),
access_token_expires_at=access_token_expires_at,
refresh_token_expires_at=None,
)

new_credentials.id = credentials.id
return new_credentials

async def _refresh_tokens(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
if not credentials.refresh_token:
raise ValueError(
"No refresh token available. Token may need to be migrated to the new refresh token system."
)
"No refresh token available."
) # Linear uses non-expiring tokens

return await self._request_tokens(
{
"refresh_token": credentials.refresh_token.get_secret_value(),
"grant_type": "refresh_token",
},
current_credentials=credentials,
}
)

async def _request_tokens(
@@ -170,33 +102,16 @@ class LinearOAuthHandler(BaseOAuthHandler):
params: dict[str, str],
current_credentials: Optional[OAuth2Credentials] = None,
) -> OAuth2Credentials:
# Determine if this is a refresh token request
is_refresh = params.get("grant_type") == "refresh_token"

# Build request body with appropriate grant_type
request_body = {
"client_id": self.client_id,
"client_secret": self.client_secret,
"grant_type": "authorization_code", # Ensure grant_type is correct
**params,
}

# Set default grant_type if not provided
if "grant_type" not in request_body:
request_body["grant_type"] = "authorization_code"

headers = {"Content-Type": "application/x-www-form-urlencoded"}

# For refresh token requests, support HTTP Basic Authentication as recommended
if is_refresh:
# Option 1: Use HTTP Basic Auth (preferred by Linear)
client_credentials = f"{self.client_id}:{self.client_secret}"
encoded_credentials = base64.b64encode(client_credentials.encode()).decode()
headers["Authorization"] = f"Basic {encoded_credentials}"

# Remove client credentials from body when using Basic Auth
request_body.pop("client_id", None)
request_body.pop("client_secret", None)

headers = {
"Content-Type": "application/x-www-form-urlencoded"
} # Correct header for token request
response = await Requests().post(
self.token_url, data=request_body, headers=headers
)
@@ -205,9 +120,6 @@ class LinearOAuthHandler(BaseOAuthHandler):
try:
error_data = response.json()
error_message = error_data.get("error", "Unknown error")
error_description = error_data.get("error_description", "")
if error_description:
error_message = f"{error_message}: {error_description}"
except json.JSONDecodeError:
error_message = response.text
raise LinearAPIException(
@@ -217,84 +129,27 @@ class LinearOAuthHandler(BaseOAuthHandler):

token_data = response.json()

# Extract token expiration if provided (for new refresh token implementation)
now = int(time.time())
expires_in = token_data.get("expires_in")
access_token_expires_at = None
if expires_in:
access_token_expires_at = now + expires_in

# Get username - preserve from current credentials if refreshing
username = None
if current_credentials and is_refresh:
username = current_credentials.username
elif "user" in token_data:
username = token_data["user"].get("name", "Unknown User")
else:
# Fetch username using the access token
username = await self._request_username(token_data["access_token"])

# Note: Linear access tokens do not expire, so we set expires_at to None
new_credentials = OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=current_credentials.title if current_credentials else None,
username=username or "Unknown User",
username=token_data.get("user", {}).get(
"name", "Unknown User"
), # extract name or set appropriate
access_token=token_data["access_token"],
scopes=(
token_data["scope"].split(",")
if "scope" in token_data
else (current_credentials.scopes if current_credentials else [])
),
refresh_token=token_data.get("refresh_token"),
access_token_expires_at=access_token_expires_at,
refresh_token_expires_at=None, # Linear doesn't provide refresh token expiration
scopes=token_data["scope"].split(
","
), # Linear returns comma-separated scopes
refresh_token=token_data.get(
"refresh_token"
), # Linear uses non-expiring tokens so this might be null
access_token_expires_at=None,
refresh_token_expires_at=None,
)

if current_credentials:
new_credentials.id = current_credentials.id

return new_credentials

async def get_access_token(self, credentials: OAuth2Credentials) -> str:
"""
Returns a valid access token, handling migration and refresh as needed.

This overrides the base implementation to handle Linear's token migration
from old long-lived tokens to new short-lived tokens with refresh tokens.
"""
# If token has no expiration and no refresh token, it might be an old token
# that needs migration
if (
credentials.access_token_expires_at is None
and credentials.refresh_token is None
):
try:
# Attempt to migrate the old token
migrated_credentials = await self.migrate_old_token(credentials)
# Update the credentials store would need to be handled by the caller
# For now, use the migrated credentials for this request
credentials = migrated_credentials
except LinearAPIException:
# Migration failed, try to use the old token as-is
# This maintains backward compatibility
pass

# Use the standard refresh logic from the base class
if self.needs_refresh(credentials):
credentials = await self.refresh_tokens(credentials)

return credentials.access_token.get_secret_value()

def needs_migration(self, credentials: OAuth2Credentials) -> bool:
"""
Check if credentials represent an old long-lived token that needs migration.

Old tokens have no expiration time and no refresh token.
"""
return (
credentials.access_token_expires_at is None
and credentials.refresh_token is None
)

async def _request_username(self, access_token: str) -> Optional[str]:
# Use the LinearClient to fetch user details using GraphQL
from ._api import LinearClient

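The refresh path in the handler above authenticates with an HTTP Basic `Authorization` header built from the client ID and secret, instead of sending them in the request body. That header construction can be sketched on its own (illustrative values; not tied to any real credentials):

```python
import base64

def basic_auth_header(client_id: str, client_secret: str) -> dict[str, str]:
    # Encode "client_id:client_secret" as base64 for the Authorization header,
    # the same shape the refresh-token branch above produces.
    raw = f"{client_id}:{client_secret}".encode()
    return {"Authorization": f"Basic {base64.b64encode(raw).decode()}"}
```

When this header is used, the client credentials are removed from the request body, as the hunk does with `request_body.pop(...)`.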
@@ -37,5 +37,5 @@ class Project(BaseModel):
name: str
description: str
priority: int
progress: float
content: str | None
progress: int
content: str

@@ -102,8 +102,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
CLAUDE_4_OPUS = "claude-opus-4-20250514"
CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
CLAUDE_3_5_HAIKU = "claude-3-5-haiku-latest"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# AI/ML API models
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
@@ -216,12 +217,15 @@ MODEL_METADATA = {
LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
"anthropic", 200000, 64000
), # claude-sonnet-4-5-20250929
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
"anthropic", 200000, 64000
), # claude-haiku-4-5-20251001
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
"anthropic", 200000, 64000
), # claude-3-7-sonnet-20250219
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata(
"anthropic", 200000, 8192
), # claude-3-5-sonnet-20241022
LlmModel.CLAUDE_3_5_HAIKU: ModelMetadata(
"anthropic", 200000, 8192
), # claude-3-5-haiku-20241022
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
"anthropic", 200000, 4096
), # claude-3-haiku-20240307
@@ -1554,9 +1558,7 @@ class AIConversationBlock(AIBlockBase):
("prompt", list),
],
test_mock={
"llm_call": lambda *args, **kwargs: dict(
response="The 2020 World Series was played at Globe Life Field in Arlington, Texas."
)
"llm_call": lambda *args, **kwargs: "The 2020 World Series was played at Globe Life Field in Arlington, Texas."
},
)

@@ -1585,7 +1587,7 @@ class AIConversationBlock(AIBlockBase):
),
credentials=credentials,
)
yield "response", response["response"]
yield "response", response
yield "prompt", self.prompt

@@ -1,226 +0,0 @@
# flake8: noqa: E501
import logging
from enum import Enum
from typing import Any, Literal

import openai
from pydantic import SecretStr

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.logging import TruncatedLogger

logger = TruncatedLogger(logging.getLogger(__name__), "[Perplexity-Block]")


class PerplexityModel(str, Enum):
"""Perplexity sonar models available via OpenRouter"""

SONAR = "perplexity/sonar"
SONAR_PRO = "perplexity/sonar-pro"
SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"


PerplexityCredentials = CredentialsMetaInput[
Literal[ProviderName.OPEN_ROUTER], Literal["api_key"]
]

TEST_CREDENTIALS = APIKeyCredentials(
id="test-perplexity-creds",
provider="open_router",
api_key=SecretStr("mock-openrouter-api-key"),
title="Mock OpenRouter API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}


def PerplexityCredentialsField() -> PerplexityCredentials:
return CredentialsField(
description="OpenRouter API key for accessing Perplexity models.",
)


class PerplexityBlock(Block):
class Input(BlockSchema):
prompt: str = SchemaField(
description="The query to send to the Perplexity model.",
placeholder="Enter your query here...",
)
model: PerplexityModel = SchemaField(
title="Perplexity Model",
default=PerplexityModel.SONAR,
description="The Perplexity sonar model to use.",
advanced=False,
)
credentials: PerplexityCredentials = PerplexityCredentialsField()
system_prompt: str = SchemaField(
title="System Prompt",
default="",
description="Optional system prompt to provide context to the model.",
advanced=True,
)
max_tokens: int | None = SchemaField(
advanced=True,
default=None,
description="The maximum number of tokens to generate.",
)

class Output(BlockSchema):
response: str = SchemaField(
description="The response from the Perplexity model."
)
annotations: list[dict[str, Any]] = SchemaField(
description="List of URL citations and annotations from the response."
)
error: str = SchemaField(description="Error message if the API call failed.")

def __init__(self):
super().__init__(
id="c8a5f2e9-8b3d-4a7e-9f6c-1d5e3c9b7a4f",
description="Query Perplexity's sonar models with real-time web search capabilities and receive annotated responses with source citations.",
categories={BlockCategory.AI, BlockCategory.SEARCH},
input_schema=PerplexityBlock.Input,
output_schema=PerplexityBlock.Output,
test_input={
"prompt": "What is the weather today?",
"model": PerplexityModel.SONAR,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("response", "The weather varies by location..."),
("annotations", list),
],
test_mock={
"call_perplexity": lambda *args, **kwargs: {
"response": "The weather varies by location...",
"annotations": [
{
"type": "url_citation",
"url_citation": {
"title": "weather.com",
"url": "https://weather.com",
},
}
],
}
},
)
self.execution_stats = NodeExecutionStats()

async def call_perplexity(
self,
credentials: APIKeyCredentials,
model: PerplexityModel,
prompt: str,
system_prompt: str = "",
max_tokens: int | None = None,
) -> dict[str, Any]:
"""Call Perplexity via OpenRouter and extract annotations."""
client = openai.AsyncOpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=credentials.api_key.get_secret_value(),
)

messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})

try:
response = await client.chat.completions.create(
extra_headers={
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=model.value,
messages=messages,
max_tokens=max_tokens,
)

if not response.choices:
raise ValueError("No response from Perplexity via OpenRouter.")

# Extract the response content
response_content = response.choices[0].message.content or ""

# Extract annotations if present in the message
annotations = []
if hasattr(response.choices[0].message, "annotations"):
# If annotations are directly available
annotations = response.choices[0].message.annotations
else:
# Check if there's a raw response with annotations
raw = getattr(response.choices[0].message, "_raw_response", None)
if isinstance(raw, dict) and "annotations" in raw:
annotations = raw["annotations"]

if not annotations and hasattr(response, "model_extra"):
# Check model_extra for annotations
model_extra = response.model_extra
if isinstance(model_extra, dict):
# Check in choices
if "choices" in model_extra and len(model_extra["choices"]) > 0:
choice = model_extra["choices"][0]
if "message" in choice and "annotations" in choice["message"]:
annotations = choice["message"]["annotations"]

# Also check the raw response object for annotations
if not annotations:
raw = getattr(response, "_raw_response", None)
if isinstance(raw, dict):
# Check various possible locations for annotations
if "annotations" in raw:
annotations = raw["annotations"]
elif "choices" in raw and len(raw["choices"]) > 0:
choice = raw["choices"][0]
if "message" in choice and "annotations" in choice["message"]:
annotations = choice["message"]["annotations"]

# Update execution stats
if response.usage:
self.execution_stats.input_token_count = response.usage.prompt_tokens
self.execution_stats.output_token_count = (
response.usage.completion_tokens
)

return {"response": response_content, "annotations": annotations or []}

except Exception as e:
logger.error(f"Error calling Perplexity: {e}")
raise

async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
logger.debug(f"Running Perplexity block with model: {input_data.model}")

try:
result = await self.call_perplexity(
credentials=credentials,
model=input_data.model,
prompt=input_data.prompt,
system_prompt=input_data.system_prompt,
max_tokens=input_data.max_tokens,
)

yield "response", result["response"]
yield "annotations", result["annotations"]

except Exception as e:
error_msg = f"Error calling Perplexity: {str(e)}"
logger.error(error_msg)
yield "error", error_msg

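The removed Perplexity block above probes several possible locations for `annotations` in the provider response: top level first, then the first choice's message. That lookup order can be sketched as a small helper over plain dicts (a simplification for illustration — the real SDK response objects are not plain dicts):

```python
from typing import Any

def find_annotations(payload: dict[str, Any]) -> list[dict[str, Any]]:
    # Prefer top-level annotations, then the first choice's message annotations.
    if "annotations" in payload:
        return payload["annotations"]
    choices = payload.get("choices") or []
    if choices:
        message = choices[0].get("message", {})
        if "annotations" in message:
            return message["annotations"]
    return []
```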
@@ -1,5 +1,7 @@
import asyncio
import logging
import urllib.parse
import urllib.request
from datetime import datetime, timedelta, timezone
from typing import Any

@@ -8,7 +10,6 @@ import pydantic

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import Requests


class RSSEntry(pydantic.BaseModel):
@@ -102,29 +103,35 @@ class ReadRSSFeedBlock(Block):
)

@staticmethod
async def parse_feed(url: str) -> dict[str, Any]:
def parse_feed(url: str) -> dict[str, Any]:
# Security fix: Add protection against memory exhaustion attacks
MAX_FEED_SIZE = 10 * 1024 * 1024 # 10MB limit for RSS feeds

# Download feed content with size limit
# Validate URL
parsed_url = urllib.parse.urlparse(url)
if parsed_url.scheme not in ("http", "https"):
raise ValueError(f"Invalid URL scheme: {parsed_url.scheme}")

# Download with size limit
try:
response = await Requests(raise_for_status=True).get(url)
with urllib.request.urlopen(url, timeout=30) as response:
# Check content length if available
content_length = response.headers.get("Content-Length")
if content_length and int(content_length) > MAX_FEED_SIZE:
raise ValueError(
f"Feed too large: {content_length} bytes exceeds {MAX_FEED_SIZE} limit"
)

# Check content length if available
content_length = response.headers.get("Content-Length")
if content_length and int(content_length) > MAX_FEED_SIZE:
raise ValueError(
f"Feed too large: {content_length} bytes exceeds {MAX_FEED_SIZE} limit"
)
# Read with size limit
content = response.read(MAX_FEED_SIZE + 1)
if len(content) > MAX_FEED_SIZE:
raise ValueError(
f"Feed too large: exceeds {MAX_FEED_SIZE} byte limit"
)

# Get content with size limit
content = response.content
if len(content) > MAX_FEED_SIZE:
raise ValueError(f"Feed too large: exceeds {MAX_FEED_SIZE} byte limit")

# Parse with feedparser using the validated content
# feedparser has built-in protection against XML attacks
return feedparser.parse(content) # type: ignore
# Parse with feedparser using the validated content
# feedparser has built-in protection against XML attacks
return feedparser.parse(content) # type: ignore
except Exception as e:
# Log error and return empty feed
logging.warning(f"Failed to parse RSS feed from {url}: {e}")
@@ -138,7 +145,7 @@ class ReadRSSFeedBlock(Block):
while keep_going:
keep_going = input_data.run_continuously

feed = await self.parse_feed(input_data.rss_url)
feed = self.parse_feed(input_data.rss_url)
all_entries = []

for entry in feed["entries"]:

@@ -1,7 +1,6 @@
import logging
import signal
import threading
import warnings
from contextlib import contextmanager
from enum import Enum

@@ -27,13 +26,6 @@ from backend.sdk (
SchemaField,
)

# Suppress false positive cleanup warning of litellm (a dependency of stagehand)
warnings.filterwarnings(
"ignore",
message="coroutine 'close_litellm_async_clients' was never awaited",
category=RuntimeWarning,
)

# Store the original method
original_register_signal_handlers = stagehand.main.Stagehand._register_signal_handlers

@@ -362,7 +362,7 @@ class TestLLMStatsTracking:
assert block.execution_stats.llm_call_count == 1

# Check output
assert outputs["response"] == "AI response to conversation"
assert outputs["response"] == {"response": "AI response to conversation"}

@pytest.mark.asyncio
async def test_ai_list_generator_with_retries(self):

@@ -1,7 +1,6 @@
from urllib.parse import parse_qs, urlparse

from youtube_transcript_api._api import YouTubeTranscriptApi
from youtube_transcript_api._errors import NoTranscriptFound
from youtube_transcript_api._transcripts import FetchedTranscript
from youtube_transcript_api.formatters import TextFormatter

@@ -65,29 +64,7 @@ class TranscribeYoutubeVideoBlock(Block):

@staticmethod
def get_transcript(video_id: str) -> FetchedTranscript:
"""
Get transcript for a video, preferring English but falling back to any available language.

:param video_id: The YouTube video ID
:return: The fetched transcript
:raises: Any exception except NoTranscriptFound for requested languages
"""
api = YouTubeTranscriptApi()
try:
# Try to get English transcript first (default behavior)
return api.fetch(video_id=video_id)
except NoTranscriptFound:
# If English is not available, get the first available transcript
transcript_list = api.list(video_id)
# Try manually created transcripts first, then generated ones
available_transcripts = list(
transcript_list._manually_created_transcripts.values()
) + list(transcript_list._generated_transcripts.values())
if available_transcripts:
# Fetch the first available transcript
return available_transcripts[0].fetch()
# If no transcripts at all, re-raise the original error
raise
return YouTubeTranscriptApi().fetch(video_id=video_id)

@staticmethod
def format_transcript(transcript: FetchedTranscript) -> str:

@@ -45,6 +45,9 @@ class MainApp(AppProcess):

app.main(silent=True)

def cleanup(self):
pass


@click.group()
def main():

@@ -1,11 +1,7 @@
|
||||
from typing import Type
|
||||
|
||||
from backend.blocks.ai_music_generator import AIMusicGeneratorBlock
|
||||
from backend.blocks.ai_shortform_video_block import (
|
||||
AIAdMakerVideoCreatorBlock,
|
||||
AIScreenshotToVideoAdBlock,
|
||||
AIShortformVideoCreatorBlock,
|
||||
)
|
||||
from backend.blocks.ai_shortform_video_block import AIShortformVideoCreatorBlock
|
||||
from backend.blocks.apollo.organization import SearchOrganizationsBlock
|
||||
from backend.blocks.apollo.people import SearchPeopleBlock
|
||||
from backend.blocks.apollo.person import GetPersonDetailBlock
|
||||
@@ -73,9 +69,10 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.CLAUDE_4_1_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_SONNET: 5,
|
||||
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
||||
LlmModel.CLAUDE_4_5_SONNET: 9,
|
||||
LlmModel.CLAUDE_3_7_SONNET: 5,
|
||||
LlmModel.CLAUDE_3_5_SONNET: 4,
|
||||
LlmModel.CLAUDE_3_5_HAIKU: 1, # $0.80 / $4.00
|
||||
LlmModel.CLAUDE_3_HAIKU: 1,
|
||||
LlmModel.AIML_API_QWEN2_5_72B: 1,
|
||||
LlmModel.AIML_API_LLAMA3_1_70B: 1,
|
||||
@@ -325,31 +322,7 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
    ],
    AIShortformVideoCreatorBlock: [
        BlockCost(
            cost_amount=307,
            cost_filter={
                "credentials": {
                    "id": revid_credentials.id,
                    "provider": revid_credentials.provider,
                    "type": revid_credentials.type,
                }
            },
        )
    ],
    AIAdMakerVideoCreatorBlock: [
        BlockCost(
            cost_amount=714,
            cost_filter={
                "credentials": {
                    "id": revid_credentials.id,
                    "provider": revid_credentials.provider,
                    "type": revid_credentials.type,
                }
            },
        )
    ],
    AIScreenshotToVideoAdBlock: [
        BlockCost(
            cost_amount=612,
            cost_amount=50,
            cost_filter={
                "credentials": {
                    "id": revid_credentials.id,
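Each BLOCK_COSTS entry above pairs a cost_amount with a cost_filter that must match the inputs (here, credentials) used at run time. A small sketch of the assumed matching rule — the filter is treated as a recursive subset of the input data; the helper name and semantics are my illustration, not the repository's implementation:

```python
def matches_filter(data: dict, flt: dict) -> bool:
    """True if every key/value pair in flt is present (recursively) in data."""
    for key, expected in flt.items():
        if key not in data:
            return False
        actual = data[key]
        if isinstance(expected, dict):
            # Nested filters (e.g. the "credentials" sub-dict) recurse
            if not isinstance(actual, dict) or not matches_filter(actual, expected):
                return False
        elif actual != expected:
            return False
    return True
```

Under this rule, extra keys in the input are ignored; only the keys named by the filter must agree, which is what lets one block carry several BlockCost entries keyed by different credential sets.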
@@ -5,6 +5,7 @@ from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, cast

import stripe
from prisma import Json
from prisma.enums import (
    CreditRefundRequestStatus,
    CreditTransactionType,
@@ -12,12 +13,16 @@ from prisma.enums import (
    OnboardingStep,
)
from prisma.errors import UniqueViolationError
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
from prisma.types import CreditRefundRequestCreateInput, CreditTransactionWhereInput
from prisma.models import CreditRefundRequest, CreditTransaction, User
from prisma.types import (
    CreditRefundRequestCreateInput,
    CreditTransactionCreateInput,
    CreditTransactionWhereInput,
)
from pydantic import BaseModel

from backend.data import db
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.db import query_raw_with_schema
from backend.data.includes import MAX_CREDIT_REFUND_REQUESTS_FETCH
from backend.data.model import (
    AutoTopUpConfig,
@@ -31,8 +36,7 @@ from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.notifications.notifications import queue_notification_async
from backend.server.v2.admin.model import UserHistoryResponse
from backend.util.exceptions import InsufficientBalanceError
from backend.util.feature_flag import Flag, is_feature_enabled
from backend.util.json import SafeJson, dumps
from backend.util.json import SafeJson
from backend.util.models import Pagination
from backend.util.retry import func_retry
from backend.util.settings import Settings
@@ -45,10 +49,6 @@ stripe.api_key = settings.secrets.stripe_api_key
logger = logging.getLogger(__name__)
base_url = settings.config.frontend_base_url or settings.config.platform_base_url

# Constants for test compatibility
POSTGRES_INT_MAX = 2147483647
POSTGRES_INT_MIN = -2147483648


class UsageTransactionMetadata(BaseModel):
    graph_exec_id: str | None = None
@@ -139,20 +139,14 @@ class UserCreditBase(ABC):
        pass

    @abstractmethod
    async def onboarding_reward(
        self, user_id: str, credits: int, step: OnboardingStep
    ) -> bool:
    async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
        """
        Reward the user with credits for completing an onboarding step.
        Won't reward if the user has already received credits for the step.

        Args:
            user_id (str): The user ID.
            credits (int): The amount to reward.
            step (OnboardingStep): The onboarding step.

        Returns:
            bool: True if rewarded, False if already rewarded.
        """
        pass
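The onboarding_reward contract above grants credits at most once per (user, step). A minimal in-memory sketch of that idempotency contract — the real code enforces it through a database unique constraint (and UniqueViolationError), not a Python set; class and attribute names here are illustrative:

```python
class RewardLedger:
    """Toy model of once-per-step onboarding rewards."""

    def __init__(self) -> None:
        self._granted: set[tuple[str, str]] = set()  # (user_id, step) pairs
        self.balances: dict[str, int] = {}

    def onboarding_reward(self, user_id: str, credits: int, step: str) -> bool:
        """Grant credits once per (user, step); return False if already granted."""
        key = (user_id, step)
        if key in self._granted:
            return False  # duplicate reward attempt is a no-op
        self._granted.add(key)
        self.balances[user_id] = self.balances.get(user_id, 0) + credits
        return True
```

The bool return mirrors the docstring in the diff: callers can distinguish "rewarded now" from "already rewarded" without a second query.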
@@ -242,12 +236,6 @@ class UserCreditBase(ABC):
        """
        Returns the current balance of the user & the latest balance snapshot time.
        """
        # Check UserBalance first for efficiency and consistency
        user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
        if user_balance:
            return user_balance.balance, user_balance.updatedAt

        # Fallback to transaction history computation if UserBalance doesn't exist
        top_time = self.time_now()
        snapshot = await CreditTransaction.prisma().find_first(
            where={
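The fallback path above derives a balance from the latest runningBalance snapshot plus any active transactions recorded after the snapshot time. The same computation in pure Python, as a sketch — the tuple shape and helper name are assumptions for illustration, not the Prisma schema:

```python
from datetime import datetime, timedelta


def balance_from_history(
    snapshot_balance: int,
    snapshot_time: datetime,
    transactions: list[tuple[int, datetime]],  # (amount, created_at) pairs
) -> int:
    """Snapshot balance plus all transaction amounts recorded after the snapshot."""
    return snapshot_balance + sum(
        amount for amount, created_at in transactions if created_at > snapshot_time
    )
```

In the diff this summation runs as a Prisma `group_by` over amounts with `createdAt > snapshot_time`, so only deltas newer than the snapshot are counted.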
@@ -262,86 +250,72 @@ class UserCreditBase(ABC):
        snapshot_balance = snapshot.runningBalance or 0 if snapshot else 0
        snapshot_time = snapshot.createdAt if snapshot else datetime_min

        return snapshot_balance, snapshot_time
        # Get transactions after the snapshot, this should not exist, but just in case.
        transactions = await CreditTransaction.prisma().group_by(
            by=["userId"],
            sum={"amount": True},
            max={"createdAt": True},
            where={
                "userId": user_id,
                "createdAt": {
                    "gt": snapshot_time,
                    "lte": top_time,
                },
                "isActive": True,
            },
        )
        transaction_balance = (
            int(transactions[0].get("_sum", {}).get("amount", 0) + snapshot_balance)
            if transactions
            else snapshot_balance
        )
        transaction_time = (
            datetime.fromisoformat(
                str(transactions[0].get("_max", {}).get("createdAt", datetime_min))
            )
            if transactions
            else snapshot_time
        )
        return transaction_balance, transaction_time

    @func_retry
    async def _enable_transaction(
        self,
        transaction_key: str,
        user_id: str,
        metadata: SafeJson,
        metadata: Json,
        new_transaction_key: str | None = None,
    ):
        # First check if transaction exists and is inactive (safety check)
        transaction = await CreditTransaction.prisma().find_first(
            where={
                "transactionKey": transaction_key,
                "userId": user_id,
                "isActive": False,
            }
        transaction = await CreditTransaction.prisma().find_first_or_raise(
            where={"transactionKey": transaction_key, "userId": user_id}
        )
        if not transaction:
            # Transaction doesn't exist or is already active, return early
            return None
        if transaction.isActive:
            return

        # Atomic operation to enable transaction and update user balance using UserBalance
        result = await query_raw_with_schema(
            """
            WITH user_balance_lock AS (
                SELECT
                    $2::text as userId,
                    COALESCE(
                        (SELECT balance FROM {schema_prefix}"UserBalance" WHERE "userId" = $2 FOR UPDATE),
                        -- Fallback: compute balance from transaction history if UserBalance doesn't exist
                        (SELECT COALESCE(ct."runningBalance", 0)
                         FROM {schema_prefix}"CreditTransaction" ct
                         WHERE ct."userId" = $2
                           AND ct."isActive" = true
                           AND ct."runningBalance" IS NOT NULL
                         ORDER BY ct."createdAt" DESC
                         LIMIT 1),
                        0
                    ) as balance
            ),
            transaction_check AS (
                SELECT * FROM {schema_prefix}"CreditTransaction"
                WHERE "transactionKey" = $1 AND "userId" = $2 AND "isActive" = false
            ),
            balance_update AS (
                INSERT INTO {schema_prefix}"UserBalance" ("userId", "balance", "updatedAt")
                SELECT
                    $2::text,
                    user_balance_lock.balance + transaction_check.amount,
                    CURRENT_TIMESTAMP
                FROM user_balance_lock, transaction_check
                ON CONFLICT ("userId") DO UPDATE SET
                    "balance" = EXCLUDED."balance",
                    "updatedAt" = EXCLUDED."updatedAt"
                RETURNING "balance", "updatedAt"
            ),
            transaction_update AS (
                UPDATE {schema_prefix}"CreditTransaction"
                SET "transactionKey" = COALESCE($4, $1),
                    "isActive" = true,
                    "runningBalance" = balance_update.balance,
                    "createdAt" = balance_update."updatedAt",
                    "metadata" = $3::jsonb
                FROM balance_update, transaction_check
                WHERE {schema_prefix}"CreditTransaction"."transactionKey" = transaction_check."transactionKey"
                  AND {schema_prefix}"CreditTransaction"."userId" = transaction_check."userId"
                RETURNING {schema_prefix}"CreditTransaction"."runningBalance"
        async with db.locked_transaction(f"usr_trx_{user_id}"):

            transaction = await CreditTransaction.prisma().find_first_or_raise(
                where={"transactionKey": transaction_key, "userId": user_id}
            )
            SELECT "runningBalance" as balance FROM transaction_update;
            """,
            transaction_key,  # $1
            user_id,  # $2
            dumps(metadata.data),  # $3 - use pre-serialized JSON string for JSONB
            new_transaction_key,  # $4
        )
            if transaction.isActive:
                return

        if result:
            # UserBalance is already updated by the CTE
            return result[0]["balance"]
            user_balance, _ = await self._get_credits(user_id)
            await CreditTransaction.prisma().update(
                where={
                    "creditTransactionIdentifier": {
                        "transactionKey": transaction_key,
                        "userId": user_id,
                    }
                },
                data={
                    "transactionKey": new_transaction_key or transaction_key,
                    "isActive": True,
                    "runningBalance": user_balance + transaction.amount,
                    "createdAt": self.time_now(),
                    "metadata": metadata,
                },
            )

    async def _add_transaction(
        self,
@@ -352,54 +326,12 @@ class UserCreditBase(ABC):
        transaction_key: str | None = None,
        ceiling_balance: int | None = None,
        fail_insufficient_credits: bool = True,
        metadata: SafeJson = SafeJson({}),
        metadata: Json = SafeJson({}),
    ) -> tuple[int, str]:
        """
        Add a new transaction for the user.
        This is the only method that should be used to add a new transaction.

        ATOMIC OPERATION DESIGN DECISION:
        ================================
        This method uses PostgreSQL row-level locking (FOR UPDATE) for atomic credit operations.
        After extensive analysis of concurrency patterns and correctness requirements, we determined
        that the FOR UPDATE approach is necessary despite the latency overhead.

        WHY FOR UPDATE LOCKING IS REQUIRED:
        ----------------------------------
        1. **Data Consistency**: Credit operations must be ACID-compliant. The balance check,
           calculation, and update must be atomic to prevent race conditions where:
           - Multiple spend operations could exceed available balance
           - Lost update problems could occur with concurrent top-ups
           - Refunds could create negative balances incorrectly

        2. **Serializability**: FOR UPDATE ensures operations are serialized at the database level,
           guaranteeing that each transaction sees a consistent view of the balance before applying changes.

        3. **Correctness Over Performance**: Financial operations require absolute correctness.
           The ~10-50ms latency increase from row locking is acceptable for the guarantee that
           no user will ever have an incorrect balance due to race conditions.

        4. **PostgreSQL Optimization**: Modern PostgreSQL versions optimize row locks efficiently.
           The performance cost is minimal compared to the complexity and risk of lock-free approaches.

        ALTERNATIVES CONSIDERED AND REJECTED:
        ------------------------------------
        - **Optimistic Concurrency**: Using version numbers or timestamps would require complex
          retry logic and could still fail under high contention scenarios.
        - **Application-Level Locking**: Redis locks or similar would add network overhead and
          single points of failure while being less reliable than database locks.
        - **Event Sourcing**: Would require complete architectural changes and eventual consistency
          models that don't fit our real-time balance requirements.

        PERFORMANCE CHARACTERISTICS:
        ---------------------------
        - Single user operations: 10-50ms latency (acceptable for financial operations)
        - Concurrent operations on same user: Serialized (prevents data corruption)
        - Concurrent operations on different users: Fully parallel (no blocking)

        This design prioritizes correctness and data integrity over raw performance,
        which is the appropriate choice for a credit/payment system.

        Args:
            user_id (str): The user ID.
            amount (int): The amount of credits to add.
@@ -413,142 +345,40 @@ class UserCreditBase(ABC):
        Returns:
            tuple[int, str]: The new balance & the transaction key.
        """
        # Quick validation for ceiling balance to avoid unnecessary database operations
        if ceiling_balance and amount > 0:
            current_balance, _ = await self._get_credits(user_id)
            if current_balance >= ceiling_balance:
        async with db.locked_transaction(f"usr_trx_{user_id}"):
            # Get latest balance snapshot
            user_balance, _ = await self._get_credits(user_id)

            if ceiling_balance and amount > 0 and user_balance >= ceiling_balance:
                raise ValueError(
                    f"You already have enough balance of ${current_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
                    f"You already have enough balance of ${user_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
                )

        # Single unified atomic operation for all transaction types using UserBalance
        try:
            result = await query_raw_with_schema(
                """
                WITH user_balance_lock AS (
                    SELECT
                        $1::text as userId,
                        -- CRITICAL: FOR UPDATE lock prevents concurrent modifications to the same user's balance
                        -- This ensures atomic read-modify-write operations and prevents race conditions
                        COALESCE(
                            (SELECT balance FROM {schema_prefix}"UserBalance" WHERE "userId" = $1 FOR UPDATE),
                            -- Fallback: compute balance from transaction history if UserBalance doesn't exist
                            (SELECT COALESCE(ct."runningBalance", 0)
                             FROM {schema_prefix}"CreditTransaction" ct
                             WHERE ct."userId" = $1
                               AND ct."isActive" = true
                               AND ct."runningBalance" IS NOT NULL
                             ORDER BY ct."createdAt" DESC
                             LIMIT 1),
                            0
                        ) as balance
                ),
                balance_update AS (
                    INSERT INTO {schema_prefix}"UserBalance" ("userId", "balance", "updatedAt")
                    SELECT
                        $1::text,
                        CASE
                            -- For inactive transactions: Don't update balance
                            WHEN $5::boolean = false THEN user_balance_lock.balance
                            -- For ceiling balance (amount > 0): Apply ceiling
                            WHEN $2 > 0 AND $7::int IS NOT NULL AND user_balance_lock.balance > $7::int - $2 THEN $7::int
                            -- For regular operations: Apply with overflow/underflow protection
                            WHEN user_balance_lock.balance + $2 > $6::int THEN $6::int
                            WHEN user_balance_lock.balance + $2 < $10::int THEN $10::int
                            ELSE user_balance_lock.balance + $2
                        END,
                        CURRENT_TIMESTAMP
                    FROM user_balance_lock
                    WHERE (
                        $5::boolean = false OR -- Allow inactive transactions
                        $2 >= 0 OR -- Allow positive amounts (top-ups, grants)
                        $8::boolean = false OR -- Allow when insufficient balance check is disabled
                        user_balance_lock.balance + $2 >= 0 -- Allow spending only when sufficient balance
            if amount < 0 and user_balance + amount < 0:
                if fail_insufficient_credits:
                    raise InsufficientBalanceError(
                        message=f"Insufficient balance of ${user_balance/100}, where this will cost ${abs(amount)/100}",
                        user_id=user_id,
                        balance=user_balance,
                        amount=amount,
                    )
                    ON CONFLICT ("userId") DO UPDATE SET
                        "balance" = EXCLUDED."balance",
                        "updatedAt" = EXCLUDED."updatedAt"
                    RETURNING "balance", "updatedAt"
                ),
                transaction_insert AS (
                    INSERT INTO {schema_prefix}"CreditTransaction" (
                        "userId", "amount", "type", "runningBalance",
                        "metadata", "isActive", "createdAt", "transactionKey"
                    )
                    SELECT
                        $1::text,
                        $2::int,
                        $3::text::{schema_prefix}"CreditTransactionType",
                        CASE
                            -- For inactive transactions: Set runningBalance to original balance (don't apply the change yet)
                            WHEN $5::boolean = false THEN user_balance_lock.balance
                            ELSE COALESCE(balance_update.balance, user_balance_lock.balance)
                        END,
                        $4::jsonb,
                        $5::boolean,
                        COALESCE(balance_update."updatedAt", CURRENT_TIMESTAMP),
                        COALESCE($9, gen_random_uuid()::text)
                    FROM user_balance_lock
                    LEFT JOIN balance_update ON true
                    WHERE (
                        $5::boolean = false OR -- Allow inactive transactions
                        $2 >= 0 OR -- Allow positive amounts (top-ups, grants)
                        $8::boolean = false OR -- Allow when insufficient balance check is disabled
                        user_balance_lock.balance + $2 >= 0 -- Allow spending only when sufficient balance
                    )
                    RETURNING "runningBalance", "transactionKey"
                )
                SELECT "runningBalance" as balance, "transactionKey" FROM transaction_insert;
                """,
                user_id,  # $1
                amount,  # $2
                transaction_type.value,  # $3
                dumps(metadata.data),  # $4 - use pre-serialized JSON string for JSONB
                is_active,  # $5
                POSTGRES_INT_MAX,  # $6 - overflow protection
                ceiling_balance,  # $7 - ceiling balance (nullable)
                fail_insufficient_credits,  # $8 - check balance for spending
                transaction_key,  # $9 - transaction key (nullable)
                POSTGRES_INT_MIN,  # $10 - underflow protection
            )
        except Exception as e:
            # Convert raw SQL unique constraint violations to UniqueViolationError
            # for consistent exception handling throughout the codebase
            error_str = str(e).lower()
            if (
                "already exists" in error_str
                or "duplicate key" in error_str
                or "unique constraint" in error_str
            ):
                # Extract table and constraint info for better error messages
                # Re-raise as a UniqueViolationError but with proper format
                # Create a minimal data structure that the error constructor expects
                raise UniqueViolationError({"error": str(e), "user_facing_error": {}})
            # For any other error, re-raise as-is
            raise

        if result:
            new_balance, tx_key = result[0]["balance"], result[0]["transactionKey"]
            # UserBalance is already updated by the CTE
            return new_balance, tx_key
                amount = min(-user_balance, 0)

        # If no result, either user doesn't exist or insufficient balance
        user = await User.prisma().find_unique(where={"id": user_id})
        if not user:
            raise ValueError(f"User {user_id} not found")

        # Must be insufficient balance for spending operation
        if amount < 0 and fail_insufficient_credits:
            current_balance, _ = await self._get_credits(user_id)
            raise InsufficientBalanceError(
                message=f"Insufficient balance of ${current_balance/100}, where this will cost ${abs(amount)/100}",
                user_id=user_id,
                balance=current_balance,
                amount=amount,
            )

        # Unexpected case
        raise ValueError(f"Transaction failed for user {user_id}, amount {amount}")
            # Create the transaction
            transaction_data: CreditTransactionCreateInput = {
                "userId": user_id,
                "amount": amount,
                "runningBalance": user_balance + amount,
                "type": transaction_type,
                "metadata": metadata,
                "isActive": is_active,
                "createdAt": self.time_now(),
            }
            if transaction_key:
                transaction_data["transactionKey"] = transaction_key
            tx = await CreditTransaction.prisma().create(data=transaction_data)
            return user_balance + amount, tx.transactionKey


class UserCredit(UserCreditBase):
@@ -623,10 +453,9 @@ class UserCredit(UserCreditBase):
                {"reason": f"Reward for completing {step.value} onboarding step."}
            ),
        )
        return True
        except UniqueViolationError:
            # User already received this reward
            return False
            # Already rewarded for this step
            pass

    async def top_up_refund(
        self, user_id: str, transaction_key: str, metadata: dict[str, str]
@@ -815,7 +644,7 @@ class UserCredit(UserCreditBase):
    ):
        # init metadata, without sharing it with the world
        metadata = metadata or {}
        if not metadata.get("reason"):
        if not metadata["reason"]:
            match top_up_type:
                case TopUpType.MANUAL:
                    metadata["reason"] = {"reason": f"Top up credits for {user_id}"}
@@ -1145,8 +974,8 @@ class DisabledUserCredit(UserCreditBase):
    async def top_up_credits(self, *args, **kwargs):
        pass

    async def onboarding_reward(self, *args, **kwargs) -> bool:
        return True
    async def onboarding_reward(self, *args, **kwargs):
        pass

    async def top_up_intent(self, *args, **kwargs) -> str:
        return ""
@@ -1164,32 +993,15 @@ class DisabledUserCredit(UserCreditBase):
        pass


async def get_user_credit_model(user_id: str) -> UserCreditBase:
    """
    Get the credit model for a user, considering LaunchDarkly flags.

    Args:
        user_id (str): The user ID to check flags for.

    Returns:
        UserCreditBase: The appropriate credit model for the user
    """
def get_user_credit_model() -> UserCreditBase:
    if not settings.config.enable_credit:
        return DisabledUserCredit()

    # Check LaunchDarkly flag for payment pilot users
    # Default to False (beta monthly credit behavior) to maintain current behavior
    is_payment_enabled = await is_feature_enabled(
        Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
    )

    if is_payment_enabled:
        # Payment enabled users get UserCredit (no monthly refills, enable payments)
        return UserCredit()
    else:
        # Default behavior: users get beta monthly credits
        if settings.config.enable_beta_monthly_credit:
            return BetaUserCredit(settings.config.num_user_credits_refill)

    return UserCredit()


def get_block_costs() -> dict[str, list["BlockCost"]]:
    return {block().id: costs for block, costs in BLOCK_COSTS.items()}
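The get_user_credit_model function above selects an implementation from the enable_credit setting and a LaunchDarkly flag. A condensed sketch of that decision table, with the config and flag lookups replaced by plain booleans and the classes reduced to names (so the branching logic can be checked in isolation):

```python
def pick_credit_model(
    enable_credit: bool, payment_enabled: bool, beta_monthly: bool
) -> str:
    """Name of the credit model the decision table above would select."""
    if not enable_credit:
        return "DisabledUserCredit"
    if payment_enabled:
        # Payment pilot users: no monthly refills, payments enabled
        return "UserCredit"
    # Default behavior: beta monthly credits when configured
    return "BetaUserCredit" if beta_monthly else "UserCredit"
```

Note the ordering: the credit kill-switch wins over everything, and the payment flag is only consulted when credits are enabled at all.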
@@ -1278,8 +1090,7 @@ async def admin_get_user_history(
        )
        reason = metadata.get("reason", "No reason provided")

        user_credit_model = await get_user_credit_model(tx.userId)
        balance, _ = await user_credit_model._get_credits(tx.userId)
        balance, last_update = await get_user_credit_model()._get_credits(tx.userId)

        history.append(
            UserTransaction(
@@ -1,172 +0,0 @@
"""
Test ceiling balance functionality to ensure auto top-up limits work correctly.

This test was added to cover a previously untested code path that could lead to
incorrect balance capping behavior.
"""

from uuid import uuid4

import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance

from backend.data.credit import UserCredit
from backend.util.json import SafeJson
from backend.util.test import SpinTestServer


async def create_test_user(user_id: str) -> None:
    """Create a test user for ceiling tests."""
    try:
        await User.prisma().create(
            data={
                "id": user_id,
                "email": f"test-{user_id}@example.com",
                "name": f"Test User {user_id[:8]}",
            }
        )
    except UniqueViolationError:
        # User already exists, continue
        pass

    await UserBalance.prisma().upsert(
        where={"userId": user_id},
        data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
    )


async def cleanup_test_user(user_id: str) -> None:
    """Clean up test user and their transactions."""
    try:
        await CreditTransaction.prisma().delete_many(where={"userId": user_id})
        await User.prisma().delete_many(where={"id": user_id})
    except Exception as e:
        # Log cleanup failures but don't fail the test
        print(f"Warning: Failed to cleanup test user {user_id}: {e}")


@pytest.mark.asyncio(loop_scope="session")
async def test_ceiling_balance_rejects_when_above_threshold(server: SpinTestServer):
    """Test that ceiling balance correctly rejects top-ups when balance is above threshold."""
    credit_system = UserCredit()
    user_id = f"ceiling-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Give user balance of 1000 ($10) using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=1000,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "initial_balance"}),
        )
        current_balance = await credit_system.get_credits(user_id)
        assert current_balance == 1000

        # Try to add 200 more with ceiling of 800 (should reject since 1000 > 800)
        with pytest.raises(ValueError, match="You already have enough balance"):
            await credit_system._add_transaction(
                user_id=user_id,
                amount=200,
                transaction_type=CreditTransactionType.TOP_UP,
                ceiling_balance=800,  # Ceiling lower than current balance
            )

        # Balance should remain unchanged
        final_balance = await credit_system.get_credits(user_id)
        assert final_balance == 1000, f"Balance should remain 1000, got {final_balance}"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_ceiling_balance_clamps_when_would_exceed(server: SpinTestServer):
    """Test that ceiling balance correctly clamps amounts that would exceed the ceiling."""
    credit_system = UserCredit()
    user_id = f"ceiling-clamp-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Give user balance of 500 ($5) using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=500,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "initial_balance"}),
        )

        # Add 800 more with ceiling of 1000 (should clamp to 1000, not reach 1300)
        final_balance, _ = await credit_system._add_transaction(
            user_id=user_id,
            amount=800,
            transaction_type=CreditTransactionType.TOP_UP,
            ceiling_balance=1000,  # Ceiling should clamp 500 + 800 = 1300 to 1000
        )

        # Balance should be clamped to ceiling
        assert (
            final_balance == 1000
        ), f"Balance should be clamped to 1000, got {final_balance}"

        # Verify with get_credits too
        stored_balance = await credit_system.get_credits(user_id)
        assert (
            stored_balance == 1000
        ), f"Stored balance should be 1000, got {stored_balance}"

        # Verify transaction shows the clamped amount
        transactions = await CreditTransaction.prisma().find_many(
            where={"userId": user_id, "type": CreditTransactionType.TOP_UP},
            order={"createdAt": "desc"},
        )

        # Should have 2 transactions: 500 + (500 to reach ceiling of 1000)
        assert len(transactions) == 2

        # The second transaction should show it only added 500, not 800
        second_tx = transactions[0]  # Most recent
        assert second_tx.runningBalance == 1000
        # The actual amount recorded could be 800 (what was requested) but balance was clamped

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_ceiling_balance_allows_when_under_threshold(server: SpinTestServer):
    """Test that ceiling balance allows top-ups when balance is under threshold."""
    credit_system = UserCredit()
    user_id = f"ceiling-under-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Give user balance of 300 ($3) using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=300,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "initial_balance"}),
        )

        # Add 200 more with ceiling of 1000 (should succeed: 300 + 200 = 500 < 1000)
        final_balance, _ = await credit_system._add_transaction(
            user_id=user_id,
            amount=200,
            transaction_type=CreditTransactionType.TOP_UP,
            ceiling_balance=1000,
        )

        # Balance should be exactly 500
        assert final_balance == 500, f"Balance should be 500, got {final_balance}"

        # Verify with get_credits too
        stored_balance = await credit_system.get_credits(user_id)
        assert (
            stored_balance == 500
        ), f"Stored balance should be 500, got {stored_balance}"

    finally:
        await cleanup_test_user(user_id)
@@ -1,737 +0,0 @@
|
||||
"""
|
||||
Concurrency and atomicity tests for the credit system.
|
||||
|
||||
These tests ensure the credit system handles high-concurrency scenarios correctly
|
||||
without race conditions, deadlocks, or inconsistent state.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
from uuid import uuid4
|
||||
|
||||
import prisma.enums
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
|
||||
from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
# Test with both UserCredit and BetaUserCredit if needed
|
||||
credit_system = UserCredit()
|
||||
|
||||
|
||||
async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user with initial balance."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
pass
|
||||
|
||||
# Ensure UserBalance record exists
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
)
|
||||
|
||||
|
||||
async def cleanup_test_user(user_id: str) -> None:
|
||||
"""Clean up test user and their transactions."""
|
||||
try:
|
||||
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
|
||||
await UserBalance.prisma().delete_many(where={"userId": user_id})
|
||||
await User.prisma().delete_many(where={"id": user_id})
|
||||
except Exception as e:
|
||||
# Log cleanup failures but don't fail the test
|
||||
print(f"Warning: Failed to cleanup test user {user_id}: {e}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_spends_same_user(server: SpinTestServer):
    """Test multiple concurrent spends from the same user don't cause race conditions."""
    user_id = f"concurrent-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Give user initial balance using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=1000,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "initial_balance"}),
        )

        # Try to spend 10 x $1 concurrently
        async def spend_one_dollar(idx: int):
            try:
                return await credit_system.spend_credits(
                    user_id,
                    100,  # $1
                    UsageTransactionMetadata(
                        graph_exec_id=f"concurrent-{idx}",
                        reason=f"Concurrent spend {idx}",
                    ),
                )
            except InsufficientBalanceError:
                return None

        # Run 10 concurrent spends
        results = await asyncio.gather(
            *[spend_one_dollar(i) for i in range(10)], return_exceptions=True
        )

        # Count successful spends; spend_one_dollar converts
        # InsufficientBalanceError to None, so failures appear as None here,
        # never as exception instances
        successful = [
            r for r in results if r is not None and not isinstance(r, Exception)
        ]
        failed = [r for r in results if r is None]

        # All 10 should succeed since we have exactly $10
        assert len(successful) == 10, f"Expected 10 successful, got {len(successful)}"
        assert len(failed) == 0, f"Expected 0 failures, got {len(failed)}"

        # Final balance should be exactly 0
        final_balance = await credit_system.get_credits(user_id)
        assert final_balance == 0, f"Expected balance 0, got {final_balance}"

        # Verify transaction history is consistent
        transactions = await CreditTransaction.prisma().find_many(
            where={"userId": user_id, "type": prisma.enums.CreditTransactionType.USAGE}
        )
        assert (
            len(transactions) == 10
        ), f"Expected 10 transactions, got {len(transactions)}"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_spends_insufficient_balance(server: SpinTestServer):
    """Test that concurrent spends correctly enforce balance limits."""
    user_id = f"insufficient-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Give user limited balance using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=500,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "limited_balance"}),
        )

        # Try to spend 10 x $1 concurrently (but only have $5)
        async def spend_one_dollar(idx: int):
            try:
                return await credit_system.spend_credits(
                    user_id,
                    100,  # $1
                    UsageTransactionMetadata(
                        graph_exec_id=f"insufficient-{idx}",
                        reason=f"Insufficient spend {idx}",
                    ),
                )
            except InsufficientBalanceError:
                return "FAILED"

        # Run 10 concurrent spends
        results = await asyncio.gather(
            *[spend_one_dollar(i) for i in range(10)], return_exceptions=True
        )

        # Count successful vs failed
        successful = [
            r
            for r in results
            if r not in ["FAILED", None] and not isinstance(r, Exception)
        ]
        failed = [r for r in results if r == "FAILED"]

        # Exactly 5 should succeed, 5 should fail
        assert len(successful) == 5, f"Expected 5 successful, got {len(successful)}"
        assert len(failed) == 5, f"Expected 5 failures, got {len(failed)}"

        # Final balance should be exactly 0
        final_balance = await credit_system.get_credits(user_id)
        assert final_balance == 0, f"Expected balance 0, got {final_balance}"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_mixed_operations(server: SpinTestServer):
    """Test concurrent mix of spends, top-ups, and balance checks."""
    user_id = f"mixed-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Initial balance using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=1000,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "initial_balance"}),
        )

        # Mix of operations
        async def mixed_operations():
            operations = []

            # 5 spends of $1 each
            for i in range(5):
                operations.append(
                    credit_system.spend_credits(
                        user_id,
                        100,
                        UsageTransactionMetadata(reason=f"Mixed spend {i}"),
                    )
                )

            # 3 top-ups of $2 each using internal method
            for i in range(3):
                operations.append(
                    credit_system._add_transaction(
                        user_id=user_id,
                        amount=200,
                        transaction_type=CreditTransactionType.TOP_UP,
                        metadata=SafeJson({"test": f"concurrent_topup_{i}"}),
                    )
                )

            # 10 balance checks
            for i in range(10):
                operations.append(credit_system.get_credits(user_id))

            return await asyncio.gather(*operations, return_exceptions=True)

        results = await mixed_operations()

        # Check no exceptions occurred
        exceptions = [
            r
            for r in results
            if isinstance(r, Exception) and not isinstance(r, InsufficientBalanceError)
        ]
        assert len(exceptions) == 0, f"Unexpected exceptions: {exceptions}"

        # Final balance should be: 1000 - 500 + 600 = 1100
        final_balance = await credit_system.get_credits(user_id)
        assert final_balance == 1100, f"Expected balance 1100, got {final_balance}"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_race_condition_exact_balance(server: SpinTestServer):
    """Test spending exact balance amount concurrently doesn't go negative."""
    user_id = f"exact-balance-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Give exact amount using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=100,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "exact_amount"}),
        )

        # Try to spend $1 twice concurrently
        async def spend_exact():
            try:
                return await credit_system.spend_credits(
                    user_id, 100, UsageTransactionMetadata(reason="Exact spend")
                )
            except InsufficientBalanceError:
                return "FAILED"

        # Both try to spend the full balance
        result1, result2 = await asyncio.gather(spend_exact(), spend_exact())

        # Exactly one should succeed
        results = [result1, result2]
        successful = [
            r for r in results if r != "FAILED" and not isinstance(r, Exception)
        ]
        failed = [r for r in results if r == "FAILED"]

        assert len(successful) == 1, f"Expected 1 success, got {len(successful)}"
        assert len(failed) == 1, f"Expected 1 failure, got {len(failed)}"

        # Balance should be exactly 0, never negative
        final_balance = await credit_system.get_credits(user_id)
        assert final_balance == 0, f"Expected balance 0, got {final_balance}"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_onboarding_reward_idempotency(server: SpinTestServer):
    """Test that onboarding rewards are idempotent (can't be claimed twice)."""
    user_id = f"onboarding-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Use WELCOME step which is defined in the OnboardingStep enum
        # Try to claim same reward multiple times concurrently
        async def claim_reward():
            try:
                result = await credit_system.onboarding_reward(
                    user_id, 500, prisma.enums.OnboardingStep.WELCOME
                )
                return "SUCCESS" if result else "DUPLICATE"
            except Exception as e:
                print(f"Claim reward failed: {e}")
                return "FAILED"

        # Try 5 concurrent claims of the same reward
        results = await asyncio.gather(*[claim_reward() for _ in range(5)])

        # Count results
        success_count = results.count("SUCCESS")
        failed_count = results.count("FAILED")

        # At least one should succeed, others should be duplicates
        assert success_count >= 1, f"At least one claim should succeed, got {results}"
        assert failed_count == 0, f"No claims should fail, got {results}"

        # Check balance - should only have 500, not 2500
        final_balance = await credit_system.get_credits(user_id)
        assert final_balance == 500, f"Expected balance 500, got {final_balance}"

        # Check only one transaction exists
        transactions = await CreditTransaction.prisma().find_many(
            where={
                "userId": user_id,
                "type": prisma.enums.CreditTransactionType.GRANT,
                "transactionKey": f"REWARD-{user_id}-WELCOME",
            }
        )
        assert (
            len(transactions) == 1
        ), f"Expected 1 reward transaction, got {len(transactions)}"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_integer_overflow_protection(server: SpinTestServer):
    """Test that integer overflow is prevented by clamping to POSTGRES_INT_MAX."""
    user_id = f"overflow-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Try to add amount that would overflow
        max_int = POSTGRES_INT_MAX

        # First, set balance near max
        await UserBalance.prisma().upsert(
            where={"userId": user_id},
            data={
                "create": {"userId": user_id, "balance": max_int - 100},
                "update": {"balance": max_int - 100},
            },
        )

        # Try to add more than possible - should clamp to POSTGRES_INT_MAX
        await credit_system._add_transaction(
            user_id=user_id,
            amount=200,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "overflow_protection"}),
        )

        # Balance should be clamped to max_int, not overflowed
        final_balance = await credit_system.get_credits(user_id)
        assert (
            final_balance == max_int
        ), f"Balance should be clamped to {max_int}, got {final_balance}"

        # Verify transaction was created with clamped amount
        transactions = await CreditTransaction.prisma().find_many(
            where={
                "userId": user_id,
                "type": prisma.enums.CreditTransactionType.TOP_UP,
            },
            order={"createdAt": "desc"},
        )
        assert len(transactions) > 0, "Transaction should be created"
        assert (
            transactions[0].runningBalance == max_int
        ), "Transaction should show clamped balance"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_high_concurrency_stress(server: SpinTestServer):
    """Stress test with many concurrent operations."""
    user_id = f"stress-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Initial balance using internal method (bypasses Stripe)
        initial_balance = 10000  # $100
        await credit_system._add_transaction(
            user_id=user_id,
            amount=initial_balance,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "stress_test_balance"}),
        )

        # Run many concurrent operations
        async def random_operation(idx: int):
            operation = random.choice(["spend", "check"])

            if operation == "spend":
                amount = random.randint(1, 50)  # $0.01 to $0.50
                try:
                    return (
                        "spend",
                        amount,
                        await credit_system.spend_credits(
                            user_id,
                            amount,
                            UsageTransactionMetadata(reason=f"Stress {idx}"),
                        ),
                    )
                except InsufficientBalanceError:
                    return ("spend_failed", amount, None)
            else:
                balance = await credit_system.get_credits(user_id)
                return ("check", 0, balance)

        # Run 100 concurrent operations
        results = await asyncio.gather(
            *[random_operation(i) for i in range(100)], return_exceptions=True
        )

        # Calculate expected final balance
        total_spent = sum(
            r[1]
            for r in results
            if not isinstance(r, Exception) and isinstance(r, tuple) and r[0] == "spend"
        )
        expected_balance = initial_balance - total_spent

        # Verify final balance
        final_balance = await credit_system.get_credits(user_id)
        assert (
            final_balance == expected_balance
        ), f"Expected {expected_balance}, got {final_balance}"
        assert final_balance >= 0, "Balance went negative!"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestServer):
    """Test multiple concurrent spends when there's sufficient balance for all."""
    user_id = f"multi-spend-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Give user 150 balance ($1.50) using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=150,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "sufficient_balance"}),
        )

        # Track individual timing to see serialization
        timings = {}

        async def spend_with_detailed_timing(amount: int, label: str):
            start = asyncio.get_event_loop().time()
            try:
                await credit_system.spend_credits(
                    user_id,
                    amount,
                    UsageTransactionMetadata(
                        graph_exec_id=f"concurrent-{label}",
                        reason=f"Concurrent spend {label}",
                    ),
                )
                end = asyncio.get_event_loop().time()
                timings[label] = {"start": start, "end": end, "duration": end - start}
                return f"{label}-SUCCESS"
            except Exception as e:
                end = asyncio.get_event_loop().time()
                timings[label] = {
                    "start": start,
                    "end": end,
                    "duration": end - start,
                    "error": str(e),
                }
                return f"{label}-FAILED: {e}"

        # Run concurrent spends: 10, 20, 30 (total 60, well under 150)
        overall_start = asyncio.get_event_loop().time()
        results = await asyncio.gather(
            spend_with_detailed_timing(10, "spend-10"),
            spend_with_detailed_timing(20, "spend-20"),
            spend_with_detailed_timing(30, "spend-30"),
            return_exceptions=True,
        )
        overall_end = asyncio.get_event_loop().time()

        print(f"Results: {results}")
        print(f"Overall duration: {overall_end - overall_start:.4f}s")

        # Analyze timing to detect serialization vs true concurrency
        print("\nTiming analysis:")
        for label, timing in timings.items():
            print(
                f"  {label}: started at {timing['start']:.4f}, ended at {timing['end']:.4f}, duration {timing['duration']:.4f}s"
            )

        # Check if operations overlapped (true concurrency) or were serialized
        sorted_timings = sorted(timings.items(), key=lambda x: x[1]["start"])
        print("\nExecution order by start time:")
        for i, (label, timing) in enumerate(sorted_timings):
            print(f"  {i+1}. {label}: {timing['start']:.4f} -> {timing['end']:.4f}")

        # Check for overlap (true concurrency) vs serialization
        overlaps = []
        for i in range(len(sorted_timings) - 1):
            current = sorted_timings[i]
            next_op = sorted_timings[i + 1]
            if current[1]["end"] > next_op[1]["start"]:
                overlaps.append(f"{current[0]} overlaps with {next_op[0]}")

        if overlaps:
            print(f"✅ TRUE CONCURRENCY detected: {overlaps}")
        else:
            print("🔒 SERIALIZATION detected: No overlapping execution times")

        # Check final balance
        final_balance = await credit_system.get_credits(user_id)
        print(f"Final balance: {final_balance}")

        # Count successes/failures
        successful = [r for r in results if "SUCCESS" in str(r)]
        failed = [r for r in results if "FAILED" in str(r)]

        print(f"Successful: {len(successful)}, Failed: {len(failed)}")

        # All should succeed since 150 - (10 + 20 + 30) = 90 > 0
        assert (
            len(successful) == 3
        ), f"Expected all 3 to succeed, got {len(successful)} successes: {results}"
        assert final_balance == 90, f"Expected balance 90, got {final_balance}"

        # Check transaction timestamps to confirm database-level serialization
        transactions = await CreditTransaction.prisma().find_many(
            where={"userId": user_id, "type": prisma.enums.CreditTransactionType.USAGE},
            order={"createdAt": "asc"},
        )
        print("\nDatabase transaction order (by createdAt):")
        for i, tx in enumerate(transactions):
            print(
                f"  {i+1}. Amount {tx.amount}, Running balance: {tx.runningBalance}, Created: {tx.createdAt}"
            )

        # Verify running balances are chronologically consistent (ordered by createdAt)
        actual_balances = [
            tx.runningBalance for tx in transactions if tx.runningBalance is not None
        ]
        print(f"Running balances: {actual_balances}")

        # The balances should be valid intermediate states regardless of execution order
        # Starting balance: 150, spending 10+20+30=60, so final should be 90
        # The intermediate balances depend on execution order but should all be valid
        expected_possible_balances = {
            # If order is 10, 20, 30: [140, 120, 90]
            # If order is 10, 30, 20: [140, 110, 90]
            # If order is 20, 10, 30: [130, 120, 90]
            # If order is 20, 30, 10: [130, 100, 90]
            # If order is 30, 10, 20: [120, 110, 90]
            # If order is 30, 20, 10: [120, 100, 90]
            90,
            100,
            110,
            120,
            130,
            140,  # All possible intermediate balances
        }

        # Verify all balances are valid intermediate states
        for balance in actual_balances:
            assert (
                balance in expected_possible_balances
            ), f"Invalid balance {balance}, expected one of {expected_possible_balances}"

        # Final balance should always be 90 (150 - 60)
        assert (
            min(actual_balances) == 90
        ), f"Final balance should be 90, got {min(actual_balances)}"

        # The final transaction should always have balance 90
        # The other transactions should have valid intermediate balances
        assert (
            90 in actual_balances
        ), f"Final balance 90 should be in actual_balances: {actual_balances}"

        # All balances should be >= 90 (the final state)
        assert all(
            balance >= 90 for balance in actual_balances
        ), f"All balances should be >= 90, got {actual_balances}"

        # CRITICAL: Transactions are atomic but can complete in any order
        # What matters is that all running balances are valid intermediate states
        # Each balance should be between 90 (final) and 140 (after first transaction)
        for balance in actual_balances:
            assert (
                90 <= balance <= 140
            ), f"Balance {balance} is outside valid range [90, 140]"

        # Final balance (minimum) should always be 90
        assert (
            min(actual_balances) == 90
        ), f"Final balance should be 90, got {min(actual_balances)}"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_prove_database_locking_behavior(server: SpinTestServer):
    """Definitively prove whether database locking causes waiting vs failures."""
    user_id = f"locking-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Set balance to exact amount that can handle all spends using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=60,  # Exactly 10+20+30
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "exact_amount_test"}),
        )

        async def spend_with_precise_timing(amount: int, label: str):
            request_start = asyncio.get_event_loop().time()
            db_operation_start = asyncio.get_event_loop().time()
            try:
                # Add a small delay to increase chance of true concurrency
                await asyncio.sleep(0.001)

                db_operation_start = asyncio.get_event_loop().time()
                await credit_system.spend_credits(
                    user_id,
                    amount,
                    UsageTransactionMetadata(
                        graph_exec_id=f"locking-{label}",
                        reason=f"Locking test {label}",
                    ),
                )
                db_operation_end = asyncio.get_event_loop().time()

                return {
                    "label": label,
                    "status": "SUCCESS",
                    "request_start": request_start,
                    "db_start": db_operation_start,
                    "db_end": db_operation_end,
                    "db_duration": db_operation_end - db_operation_start,
                }
            except Exception as e:
                db_operation_end = asyncio.get_event_loop().time()
                return {
                    "label": label,
                    "status": "FAILED",
                    "error": str(e),
                    "request_start": request_start,
                    "db_start": db_operation_start,
                    "db_end": db_operation_end,
                    "db_duration": db_operation_end - db_operation_start,
                }

        # Launch all requests simultaneously
        results = await asyncio.gather(
            spend_with_precise_timing(10, "A"),
            spend_with_precise_timing(20, "B"),
            spend_with_precise_timing(30, "C"),
            return_exceptions=True,
        )

        print("\n🔍 LOCKING BEHAVIOR ANALYSIS:")
        print("=" * 50)

        successful = [
            r for r in results if isinstance(r, dict) and r.get("status") == "SUCCESS"
        ]
        failed = [
            r for r in results if isinstance(r, dict) and r.get("status") == "FAILED"
        ]

        print(f"✅ Successful operations: {len(successful)}")
        print(f"❌ Failed operations: {len(failed)}")

        if len(failed) > 0:
            print(
                "\n🚫 CONCURRENT FAILURES - Some requests failed due to insufficient balance:"
            )
            for result in failed:
                if isinstance(result, dict):
                    print(
                        f"  {result['label']}: {result.get('error', 'Unknown error')}"
                    )

        if len(successful) == 3:
            print(
                "\n🔒 SERIALIZATION CONFIRMED - All requests succeeded, indicating they were queued:"
            )

            # Sort by actual execution time to see order
            dict_results = [r for r in results if isinstance(r, dict)]
            sorted_results = sorted(dict_results, key=lambda x: x["db_start"])

            for i, result in enumerate(sorted_results):
                print(
                    f"  {i+1}. {result['label']}: DB operation took {result['db_duration']:.4f}s"
                )

            # Check if any operations overlapped at the database level
            print("\n⏱️ Database operation timeline:")
            for result in sorted_results:
                print(
                    f"  {result['label']}: {result['db_start']:.4f} -> {result['db_end']:.4f}"
                )

        # Verify final state
        final_balance = await credit_system.get_credits(user_id)
        print(f"\n💰 Final balance: {final_balance}")

        if len(successful) == 3:
            assert (
                final_balance == 0
            ), f"If all succeeded, balance should be 0, got {final_balance}"
            print(
                "✅ CONCLUSION: Database row locking causes requests to WAIT and execute serially"
            )
        else:
            print(
                "❌ CONCLUSION: Some requests failed, indicating different concurrency behavior"
            )

    finally:
        await cleanup_test_user(user_id)
@@ -1,277 +0,0 @@
"""
Integration tests for credit system to catch SQL enum casting issues.

These tests run actual database operations to ensure SQL queries work correctly,
which would have caught the CreditTransactionType enum casting bug.
"""

import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, User, UserBalance

from backend.data.credit import (
    AutoTopUpConfig,
    BetaUserCredit,
    UsageTransactionMetadata,
    get_auto_top_up,
    set_auto_top_up,
)
from backend.util.json import SafeJson


@pytest.fixture
async def cleanup_test_user():
    """Clean up test user data before and after tests."""
    import uuid

    user_id = str(uuid.uuid4())  # Use unique user ID for each test

    # Create the user first
    try:
        await User.prisma().create(
            data={
                "id": user_id,
                "email": f"test-{user_id}@example.com",
                "topUpConfig": SafeJson({}),
                "timezone": "UTC",
            }
        )
    except Exception:
        # User might already exist, that's fine
        pass

    yield user_id

    # Cleanup after test
    await CreditTransaction.prisma().delete_many(where={"userId": user_id})
    await UserBalance.prisma().delete_many(where={"userId": user_id})
    # Clear auto-top-up config before deleting user
    await User.prisma().update(
        where={"id": user_id}, data={"topUpConfig": SafeJson({})}
    )
    await User.prisma().delete(where={"id": user_id})


@pytest.mark.asyncio(loop_scope="session")
async def test_credit_transaction_enum_casting_integration(cleanup_test_user):
    """
    Integration test to verify CreditTransactionType enum casting works in SQL queries.

    This test would have caught the enum casting bug where PostgreSQL expected
    platform."CreditTransactionType" but got "CreditTransactionType".
    """
    user_id = cleanup_test_user
    credit_system = BetaUserCredit(1000)

    # Test each transaction type to ensure enum casting works
    test_cases = [
        (CreditTransactionType.TOP_UP, 100, "Test top-up"),
        (CreditTransactionType.USAGE, -50, "Test usage"),
        (CreditTransactionType.GRANT, 200, "Test grant"),
        (CreditTransactionType.REFUND, -25, "Test refund"),
        (CreditTransactionType.CARD_CHECK, 0, "Test card check"),
    ]

    for transaction_type, amount, reason in test_cases:
        metadata = SafeJson({"reason": reason, "test": "enum_casting"})

        # This call would fail with enum casting error before the fix
        balance, tx_key = await credit_system._add_transaction(
            user_id=user_id,
            amount=amount,
            transaction_type=transaction_type,
            metadata=metadata,
            is_active=True,
        )

        # Verify transaction was created with correct type
        transaction = await CreditTransaction.prisma().find_first(
            where={"userId": user_id, "transactionKey": tx_key}
        )

        assert transaction is not None
        assert transaction.type == transaction_type
        assert transaction.amount == amount
        assert transaction.metadata is not None

        # Verify metadata content
        assert transaction.metadata["reason"] == reason
        assert transaction.metadata["test"] == "enum_casting"


@pytest.mark.asyncio(loop_scope="session")
async def test_auto_top_up_integration(cleanup_test_user, monkeypatch):
    """
    Integration test for auto-top-up functionality that triggers enum casting.

    This tests the complete auto-top-up flow which involves SQL queries with
    CreditTransactionType enums, ensuring enum casting works end-to-end.
    """
    # Enable credits for this test
    from backend.data.credit import settings

    monkeypatch.setattr(settings.config, "enable_credit", True)
    monkeypatch.setattr(settings.config, "enable_beta_monthly_credit", True)
    monkeypatch.setattr(settings.config, "num_user_credits_refill", 1000)

    user_id = cleanup_test_user
    credit_system = BetaUserCredit(1000)

    # First add some initial credits so we can test the configuration and subsequent behavior
    balance, _ = await credit_system._add_transaction(
        user_id=user_id,
        amount=50,  # Below threshold that we'll set
        transaction_type=CreditTransactionType.GRANT,
        metadata=SafeJson({"reason": "Initial credits before auto top-up config"}),
    )
    assert balance == 50

    # Configure auto top-up with threshold above current balance
    config = AutoTopUpConfig(threshold=100, amount=500)
    await set_auto_top_up(user_id, config)

    # Verify configuration was saved but no immediate top-up occurred
    current_balance = await credit_system.get_credits(user_id)
    assert current_balance == 50  # Balance should be unchanged

    # Simulate spending credits that would trigger auto top-up
    # This involves multiple SQL operations with enum casting
    try:
        metadata = UsageTransactionMetadata(reason="Test spend to trigger auto top-up")
        await credit_system.spend_credits(user_id=user_id, cost=10, metadata=metadata)

        # The auto top-up mechanism should have been triggered
        # Verify the transaction types were handled correctly
        transactions = await CreditTransaction.prisma().find_many(
            where={"userId": user_id}, order={"createdAt": "desc"}
        )

        # Should have at least: GRANT (initial), USAGE (spend), and TOP_UP (auto top-up)
        assert len(transactions) >= 3

        # Verify different transaction types exist and enum casting worked
        transaction_types = {t.type for t in transactions}
        assert CreditTransactionType.GRANT in transaction_types
        assert CreditTransactionType.USAGE in transaction_types
        assert (
            CreditTransactionType.TOP_UP in transaction_types
        )  # Auto top-up should have triggered

    except Exception as e:
        # If this fails with enum casting error, the test successfully caught the bug
        if "CreditTransactionType" in str(e) and (
            "cast" in str(e).lower() or "type" in str(e).lower()
        ):
            pytest.fail(f"Enum casting error detected: {e}")
        else:
            # Re-raise other unexpected errors
            raise


@pytest.mark.asyncio(loop_scope="session")
async def test_enable_transaction_enum_casting_integration(cleanup_test_user):
    """
    Integration test for _enable_transaction with enum casting.

    Tests the scenario where inactive transactions are enabled, which also
    involves SQL queries with CreditTransactionType enum casting.
    """
    user_id = cleanup_test_user
    credit_system = BetaUserCredit(1000)

    # Create an inactive transaction
    balance, tx_key = await credit_system._add_transaction(
        user_id=user_id,
        amount=100,
        transaction_type=CreditTransactionType.TOP_UP,
        metadata=SafeJson({"reason": "Inactive transaction test"}),
        is_active=False,  # Create as inactive
    )

    # Balance should be 0 since transaction is inactive
    assert balance == 0

    # Enable the transaction with new metadata
    enable_metadata = SafeJson(
        {
            "payment_method": "test_payment",
            "activation_reason": "Integration test activation",
        }
    )

    # This would fail with enum casting error before the fix
    final_balance = await credit_system._enable_transaction(
        transaction_key=tx_key,
        user_id=user_id,
        metadata=enable_metadata,
    )

    # Now balance should reflect the activated transaction
    assert final_balance == 100

    # Verify transaction was properly enabled with correct enum type
    transaction = await CreditTransaction.prisma().find_first(
        where={"userId": user_id, "transactionKey": tx_key}
    )

    assert transaction is not None
    assert transaction.isActive is True
    assert transaction.type == CreditTransactionType.TOP_UP
    assert transaction.runningBalance == 100

    # Verify metadata was updated
    assert transaction.metadata is not None
    assert transaction.metadata["payment_method"] == "test_payment"
    assert transaction.metadata["activation_reason"] == "Integration test activation"


@pytest.mark.asyncio(loop_scope="session")
async def test_auto_top_up_configuration_storage(cleanup_test_user, monkeypatch):
    """
    Test that auto-top-up configuration is properly stored and retrieved.

    The immediate top-up logic is handled by the API routes, not the core
    set_auto_top_up function. This test verifies the configuration is correctly
    saved and can be retrieved.
    """
    # Enable credits for this test
    from backend.data.credit import settings

    monkeypatch.setattr(settings.config, "enable_credit", True)
    monkeypatch.setattr(settings.config, "enable_beta_monthly_credit", True)
    monkeypatch.setattr(settings.config, "num_user_credits_refill", 1000)

    user_id = cleanup_test_user
    credit_system = BetaUserCredit(1000)

    # Set initial balance
    balance, _ = await credit_system._add_transaction(
        user_id=user_id,
        amount=50,
        transaction_type=CreditTransactionType.GRANT,
        metadata=SafeJson({"reason": "Initial balance for config test"}),
    )

    assert balance == 50

    # Configure auto top-up
    config = AutoTopUpConfig(threshold=100, amount=200)
    await set_auto_top_up(user_id, config)

    # Verify the configuration was saved
    retrieved_config = await get_auto_top_up(user_id)
    assert retrieved_config.threshold == config.threshold
    assert retrieved_config.amount == config.amount

    # Verify balance is unchanged (no immediate top-up from set_auto_top_up)
    final_balance = await credit_system.get_credits(user_id)
    assert final_balance == 50  # Should be unchanged

    # Verify no immediate auto-top-up transaction was created by set_auto_top_up
    transactions = await CreditTransaction.prisma().find_many(
        where={"userId": user_id}, order={"createdAt": "desc"}
|
||||
)
|
||||
|
||||
# Should only have the initial GRANT transaction
|
||||
assert len(transactions) == 1
|
||||
assert transactions[0].type == CreditTransactionType.GRANT
|
||||
@@ -1,141 +0,0 @@
"""
Tests for credit system metadata handling to ensure JSON casting works correctly.

This test verifies that metadata parameters are properly serialized when passed
to raw SQL queries with JSONB columns.
"""

# type: ignore

from typing import Any

import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, UserBalance

from backend.data.credit import BetaUserCredit
from backend.data.user import DEFAULT_USER_ID
from backend.util.json import SafeJson


@pytest.fixture
async def setup_test_user():
    """Setup test user and cleanup after test."""
    user_id = DEFAULT_USER_ID

    # Cleanup before test
    await CreditTransaction.prisma().delete_many(where={"userId": user_id})
    await UserBalance.prisma().delete_many(where={"userId": user_id})

    yield user_id

    # Cleanup after test
    await CreditTransaction.prisma().delete_many(where={"userId": user_id})
    await UserBalance.prisma().delete_many(where={"userId": user_id})


@pytest.mark.asyncio(loop_scope="session")
async def test_metadata_json_serialization(setup_test_user):
    """Test that metadata is properly serialized for JSONB column in raw SQL."""
    user_id = setup_test_user
    credit_system = BetaUserCredit(1000)

    # Test with complex metadata that would fail if not properly serialized
    complex_metadata = SafeJson(
        {
            "graph_exec_id": "test-12345",
            "reason": "Testing metadata serialization",
            "nested_data": {
                "key1": "value1",
                "key2": ["array", "of", "values"],
                "key3": {"deeply": {"nested": "object"}},
            },
            "special_chars": "Testing 'quotes' and \"double quotes\" and unicode: 🚀",
        }
    )

    # This should work without throwing a JSONB casting error
    balance, tx_key = await credit_system._add_transaction(
        user_id=user_id,
        amount=500,  # $5 top-up
        transaction_type=CreditTransactionType.TOP_UP,
        metadata=complex_metadata,
        is_active=True,
    )

    # Verify the transaction was created successfully
    assert balance == 500

    # Verify the metadata was stored correctly in the database
    transaction = await CreditTransaction.prisma().find_first(
        where={"userId": user_id, "transactionKey": tx_key}
    )

    assert transaction is not None
    assert transaction.metadata is not None

    # Verify the metadata contains our complex data
    metadata_dict: dict[str, Any] = dict(transaction.metadata)  # type: ignore
    assert metadata_dict["graph_exec_id"] == "test-12345"
    assert metadata_dict["reason"] == "Testing metadata serialization"
    assert metadata_dict["nested_data"]["key1"] == "value1"
    assert metadata_dict["nested_data"]["key3"]["deeply"]["nested"] == "object"
    assert (
        metadata_dict["special_chars"]
        == "Testing 'quotes' and \"double quotes\" and unicode: 🚀"
    )


@pytest.mark.asyncio(loop_scope="session")
async def test_enable_transaction_metadata_serialization(setup_test_user):
    """Test that _enable_transaction also handles metadata JSON serialization correctly."""
    user_id = setup_test_user
    credit_system = BetaUserCredit(1000)

    # First create an inactive transaction
    balance, tx_key = await credit_system._add_transaction(
        user_id=user_id,
        amount=300,
        transaction_type=CreditTransactionType.TOP_UP,
        metadata=SafeJson({"initial": "inactive_transaction"}),
        is_active=False,  # Create as inactive
    )

    # Initial balance should be 0 because transaction is inactive
    assert balance == 0

    # Now enable the transaction with new metadata
    enable_metadata = SafeJson(
        {
            "payment_method": "stripe",
            "payment_intent": "pi_test_12345",
            "activation_reason": "Payment confirmed",
            "complex_data": {"array": [1, 2, 3], "boolean": True, "null_value": None},
        }
    )

    # This should work without JSONB casting errors
    final_balance = await credit_system._enable_transaction(
        transaction_key=tx_key,
        user_id=user_id,
        metadata=enable_metadata,
    )

    # Now balance should reflect the activated transaction
    assert final_balance == 300

    # Verify the metadata was updated correctly
    transaction = await CreditTransaction.prisma().find_first(
        where={"userId": user_id, "transactionKey": tx_key}
    )

    assert transaction is not None
    assert transaction.isActive is True

    # Verify the metadata was updated with enable_metadata
    metadata_dict: dict[str, Any] = dict(transaction.metadata)  # type: ignore
    assert metadata_dict["payment_method"] == "stripe"
    assert metadata_dict["payment_intent"] == "pi_test_12345"
    assert metadata_dict["complex_data"]["array"] == [1, 2, 3]
    assert metadata_dict["complex_data"]["boolean"] is True
    assert metadata_dict["complex_data"]["null_value"] is None

@@ -1,372 +0,0 @@
"""
Tests for credit system refund and dispute operations.

These tests ensure that refund operations (deduct_credits, handle_dispute)
are atomic and maintain data consistency.
"""

from datetime import datetime, timezone
from unittest.mock import MagicMock, patch

import pytest
import stripe
from prisma.enums import CreditTransactionType
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance

from backend.data.credit import UserCredit
from backend.util.json import SafeJson
from backend.util.test import SpinTestServer

credit_system = UserCredit()

# Test user ID for refund tests
REFUND_TEST_USER_ID = "refund-test-user"


async def setup_test_user_with_topup():
    """Create a test user with initial balance and a top-up transaction."""
    # Clean up any existing data
    await CreditRefundRequest.prisma().delete_many(
        where={"userId": REFUND_TEST_USER_ID}
    )
    await CreditTransaction.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
    await UserBalance.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
    await User.prisma().delete_many(where={"id": REFUND_TEST_USER_ID})

    # Create user
    await User.prisma().create(
        data={
            "id": REFUND_TEST_USER_ID,
            "email": f"{REFUND_TEST_USER_ID}@example.com",
            "name": "Refund Test User",
        }
    )

    # Create user balance
    await UserBalance.prisma().create(
        data={
            "userId": REFUND_TEST_USER_ID,
            "balance": 1000,  # $10
        }
    )

    # Create a top-up transaction that can be refunded
    topup_tx = await CreditTransaction.prisma().create(
        data={
            "userId": REFUND_TEST_USER_ID,
            "amount": 1000,
            "type": CreditTransactionType.TOP_UP,
            "transactionKey": "pi_test_12345",
            "runningBalance": 1000,
            "isActive": True,
            "metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
        }
    )

    return topup_tx


async def cleanup_test_user():
    """Clean up test data."""
    await CreditRefundRequest.prisma().delete_many(
        where={"userId": REFUND_TEST_USER_ID}
    )
    await CreditTransaction.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
    await UserBalance.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
    await User.prisma().delete_many(where={"id": REFUND_TEST_USER_ID})


@pytest.mark.asyncio(loop_scope="session")
async def test_deduct_credits_atomic(server: SpinTestServer):
    """Test that deduct_credits is atomic and creates transaction correctly."""
    topup_tx = await setup_test_user_with_topup()

    try:
        # Create a mock refund object
        refund = MagicMock(spec=stripe.Refund)
        refund.id = "re_test_refund_123"
        refund.payment_intent = topup_tx.transactionKey
        refund.amount = 500  # Refund $5 of the $10 top-up
        refund.status = "succeeded"
        refund.reason = "requested_by_customer"
        refund.created = int(datetime.now(timezone.utc).timestamp())

        # Create refund request record (simulating webhook flow)
        await CreditRefundRequest.prisma().create(
            data={
                "userId": REFUND_TEST_USER_ID,
                "amount": 500,
                "transactionKey": topup_tx.transactionKey,  # Should match the original transaction
                "reason": "Test refund",
            }
        )

        # Call deduct_credits
        await credit_system.deduct_credits(refund)

        # Verify the user's balance was deducted
        user_balance = await UserBalance.prisma().find_unique(
            where={"userId": REFUND_TEST_USER_ID}
        )
        assert user_balance is not None
        assert (
            user_balance.balance == 500
        ), f"Expected balance 500, got {user_balance.balance}"

        # Verify refund transaction was created
        refund_tx = await CreditTransaction.prisma().find_first(
            where={
                "userId": REFUND_TEST_USER_ID,
                "type": CreditTransactionType.REFUND,
                "transactionKey": refund.id,
            }
        )
        assert refund_tx is not None
        assert refund_tx.amount == -500
        assert refund_tx.runningBalance == 500
        assert refund_tx.isActive

        # Verify refund request was updated
        refund_request = await CreditRefundRequest.prisma().find_first(
            where={
                "userId": REFUND_TEST_USER_ID,
                "transactionKey": topup_tx.transactionKey,
            }
        )
        assert refund_request is not None
        assert (
            refund_request.result
            == "The refund request has been approved, the amount will be credited back to your account."
        )

    finally:
        await cleanup_test_user()


@pytest.mark.asyncio(loop_scope="session")
async def test_deduct_credits_user_not_found(server: SpinTestServer):
    """Test that deduct_credits raises error if transaction not found (which means user doesn't exist)."""
    # Create a mock refund object that references a non-existent payment intent
    refund = MagicMock(spec=stripe.Refund)
    refund.id = "re_test_refund_nonexistent"
    refund.payment_intent = "pi_test_nonexistent"  # This payment intent doesn't exist
    refund.amount = 500
    refund.status = "succeeded"
    refund.reason = "requested_by_customer"
    refund.created = int(datetime.now(timezone.utc).timestamp())

    # Should raise error for missing transaction
    with pytest.raises(Exception):  # Should raise NotFoundError for missing transaction
        await credit_system.deduct_credits(refund)


@pytest.mark.asyncio(loop_scope="session")
@patch("backend.data.credit.settings")
@patch("stripe.Dispute.modify")
@patch("backend.data.credit.get_user_by_id")
async def test_handle_dispute_with_sufficient_balance(
    mock_get_user, mock_stripe_modify, mock_settings, server: SpinTestServer
):
    """Test handling dispute when user has sufficient balance (dispute gets closed)."""
    topup_tx = await setup_test_user_with_topup()

    try:
        # Mock settings to have a low tolerance threshold
        mock_settings.config.refund_credit_tolerance_threshold = 0

        # Mock the user lookup
        mock_user = MagicMock()
        mock_user.email = f"{REFUND_TEST_USER_ID}@example.com"
        mock_get_user.return_value = mock_user

        # Create a mock dispute object for small amount (user has 1000, disputing 100)
        dispute = MagicMock(spec=stripe.Dispute)
        dispute.id = "dp_test_dispute_123"
        dispute.payment_intent = topup_tx.transactionKey
        dispute.amount = 100  # Small dispute amount
        dispute.status = "pending"
        dispute.reason = "fraudulent"
        dispute.created = int(datetime.now(timezone.utc).timestamp())

        # Mock the close method to prevent real API calls
        dispute.close = MagicMock()

        # Handle the dispute
        await credit_system.handle_dispute(dispute)

        # Verify dispute.close() was called (since user has sufficient balance)
        dispute.close.assert_called_once()

        # Verify no stripe evidence was added since dispute was closed
        mock_stripe_modify.assert_not_called()

        # Verify the user's balance was NOT deducted (dispute was closed)
        user_balance = await UserBalance.prisma().find_unique(
            where={"userId": REFUND_TEST_USER_ID}
        )
        assert user_balance is not None
        assert (
            user_balance.balance == 1000
        ), f"Balance should remain 1000, got {user_balance.balance}"

    finally:
        await cleanup_test_user()


@pytest.mark.asyncio(loop_scope="session")
@patch("backend.data.credit.settings")
@patch("stripe.Dispute.modify")
@patch("backend.data.credit.get_user_by_id")
async def test_handle_dispute_with_insufficient_balance(
    mock_get_user, mock_stripe_modify, mock_settings, server: SpinTestServer
):
    """Test handling dispute when user has insufficient balance (evidence gets added)."""
    topup_tx = await setup_test_user_with_topup()

    # Save original method for restoration before any try blocks
    original_get_history = credit_system.get_transaction_history

    try:
        # Mock settings to have a high tolerance threshold so dispute isn't closed
        mock_settings.config.refund_credit_tolerance_threshold = 2000

        # Mock the user lookup
        mock_user = MagicMock()
        mock_user.email = f"{REFUND_TEST_USER_ID}@example.com"
        mock_get_user.return_value = mock_user

        # Mock the transaction history method to return an async result
        from unittest.mock import AsyncMock

        mock_history = MagicMock()
        mock_history.transactions = []
        credit_system.get_transaction_history = AsyncMock(return_value=mock_history)

        # Create a mock dispute object for full amount (user has 1000, disputing 1000)
        dispute = MagicMock(spec=stripe.Dispute)
        dispute.id = "dp_test_dispute_pending"
        dispute.payment_intent = topup_tx.transactionKey
        dispute.amount = 1000
        dispute.status = "warning_needs_response"
        dispute.created = int(datetime.now(timezone.utc).timestamp())

        # Mock the close method to prevent real API calls
        dispute.close = MagicMock()

        # Handle the dispute (evidence should be added)
        await credit_system.handle_dispute(dispute)

        # Verify dispute.close() was NOT called (insufficient balance after tolerance)
        dispute.close.assert_not_called()

        # Verify stripe evidence was added since dispute wasn't closed
        mock_stripe_modify.assert_called_once()

        # Verify the user's balance was NOT deducted (handle_dispute doesn't deduct credits)
        user_balance = await UserBalance.prisma().find_unique(
            where={"userId": REFUND_TEST_USER_ID}
        )
        assert user_balance is not None
        assert user_balance.balance == 1000, "Balance should remain unchanged"

    finally:
        credit_system.get_transaction_history = original_get_history
        await cleanup_test_user()


@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_refunds(server: SpinTestServer):
    """Test that concurrent refunds are handled atomically."""
    import asyncio

    topup_tx = await setup_test_user_with_topup()

    try:
        # Create multiple refund requests
        refund_requests = []
        for i in range(5):
            req = await CreditRefundRequest.prisma().create(
                data={
                    "userId": REFUND_TEST_USER_ID,
                    "amount": 100,  # $1 each
                    "transactionKey": topup_tx.transactionKey,
                    "reason": f"Test refund {i}",
                }
            )
            refund_requests.append(req)

        # Create refund tasks to run concurrently
        async def process_refund(index: int):
            refund = MagicMock(spec=stripe.Refund)
            refund.id = f"re_test_concurrent_{index}"
            refund.payment_intent = topup_tx.transactionKey
            refund.amount = 100  # $1 refund
            refund.status = "succeeded"
            refund.reason = "requested_by_customer"
            refund.created = int(datetime.now(timezone.utc).timestamp())

            try:
                await credit_system.deduct_credits(refund)
                return "success"
            except Exception as e:
                return f"error: {e}"

        # Run refunds concurrently
        results = await asyncio.gather(
            *[process_refund(i) for i in range(5)], return_exceptions=True
        )

        # All should succeed
        assert all(r == "success" for r in results), f"Some refunds failed: {results}"

        # Verify final balance: with an atomic implementation all five refunds
        # are applied; a non-atomic implementation would lose updates here and
        # leave the balance too high.
        user_balance = await UserBalance.prisma().find_unique(
            where={"userId": REFUND_TEST_USER_ID}
        )
        assert user_balance is not None

        # With atomic implementation, all 5 refunds should process correctly
        assert (
            user_balance.balance == 500
        ), f"Expected balance 500 after 5 refunds of 100 each, got {user_balance.balance}"

        # Verify all refund transactions exist
        refund_txs = await CreditTransaction.prisma().find_many(
            where={
                "userId": REFUND_TEST_USER_ID,
                "type": CreditTransactionType.REFUND,
            }
        )
        assert (
            len(refund_txs) == 5
        ), f"Expected 5 refund transactions, got {len(refund_txs)}"

        running_balances: set[int] = {
            tx.runningBalance for tx in refund_txs if tx.runningBalance is not None
        }

        # Verify all balances are valid intermediate states
        for balance in running_balances:
            assert (
                500 <= balance <= 1000
            ), f"Invalid balance {balance}, should be between 500 and 1000"

        # Final balance should be present
        assert (
            500 in running_balances
        ), f"Final balance 500 should be in {running_balances}"

        # All balances should be unique and form a valid sequence
        sorted_balances = sorted(running_balances, reverse=True)
        assert (
            len(sorted_balances) == 5
        ), f"Expected 5 unique balances, got {len(sorted_balances)}"

    finally:
        await cleanup_test_user()

@@ -1,8 +1,8 @@
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timezone

 import pytest
 from prisma.enums import CreditTransactionType
-from prisma.models import CreditTransaction, UserBalance
+from prisma.models import CreditTransaction

 from backend.blocks.llm import AITextGeneratorBlock
 from backend.data.block import get_block
@@ -19,24 +19,14 @@ user_credit = BetaUserCredit(REFILL_VALUE)

 async def disable_test_user_transactions():
     await CreditTransaction.prisma().delete_many(where={"userId": DEFAULT_USER_ID})
-    # Also reset the balance to 0 and set updatedAt to old date to trigger monthly refill
-    old_date = datetime.now(timezone.utc) - timedelta(days=35)  # More than a month ago
-    await UserBalance.prisma().upsert(
-        where={"userId": DEFAULT_USER_ID},
-        data={
-            "create": {"userId": DEFAULT_USER_ID, "balance": 0},
-            "update": {"balance": 0, "updatedAt": old_date},
-        },
-    )


 async def top_up(amount: int):
-    balance, _ = await user_credit._add_transaction(
+    await user_credit._add_transaction(
         DEFAULT_USER_ID,
         amount,
         CreditTransactionType.TOP_UP,
     )
-    return balance


 async def spend_credits(entry: NodeExecutionEntry) -> int:
@@ -121,90 +111,29 @@ async def test_block_credit_top_up(server: SpinTestServer):

 @pytest.mark.asyncio(loop_scope="session")
 async def test_block_credit_reset(server: SpinTestServer):
     """Test that BetaUserCredit provides monthly refills correctly."""
     await disable_test_user_transactions()
+    month1 = 1
+    month2 = 2

-    # Save original time_now function for restoration
-    original_time_now = user_credit.time_now
+    # set the calendar to month 2 but use current time from now
+    user_credit.time_now = lambda: datetime.now(timezone.utc).replace(
+        month=month2, day=1
+    )
+    month2credit = await user_credit.get_credits(DEFAULT_USER_ID)

-    try:
-        # Test month 1 behavior
-        month1 = datetime.now(timezone.utc).replace(month=1, day=1)
-        user_credit.time_now = lambda: month1
+    # Month 1 result should only affect month 1
+    user_credit.time_now = lambda: datetime.now(timezone.utc).replace(
+        month=month1, day=1
+    )
+    month1credit = await user_credit.get_credits(DEFAULT_USER_ID)
+    await top_up(100)
+    assert await user_credit.get_credits(DEFAULT_USER_ID) == month1credit + 100

-        # First call in month 1 should trigger refill
-        balance = await user_credit.get_credits(DEFAULT_USER_ID)
-        assert balance == REFILL_VALUE  # Should get 1000 credits
-
-        # Manually create a transaction with month 1 timestamp to establish history
-        await CreditTransaction.prisma().create(
-            data={
-                "userId": DEFAULT_USER_ID,
-                "amount": 100,
-                "type": CreditTransactionType.TOP_UP,
-                "runningBalance": 1100,
-                "isActive": True,
-                "createdAt": month1,  # Set specific timestamp
-            }
-        )
-
-        # Update user balance to match
-        await UserBalance.prisma().upsert(
-            where={"userId": DEFAULT_USER_ID},
-            data={
-                "create": {"userId": DEFAULT_USER_ID, "balance": 1100},
-                "update": {"balance": 1100},
-            },
-        )
-
-        # Now test month 2 behavior
-        month2 = datetime.now(timezone.utc).replace(month=2, day=1)
-        user_credit.time_now = lambda: month2
-
-        # In month 2, since balance (1100) > refill (1000), no refill should happen
-        month2_balance = await user_credit.get_credits(DEFAULT_USER_ID)
-        assert month2_balance == 1100  # Balance persists, no reset
-
-        # Now test the refill behavior when balance is low
-        # Set balance below refill threshold
-        await UserBalance.prisma().update(
-            where={"userId": DEFAULT_USER_ID}, data={"balance": 400}
-        )
-
-        # Create a month 2 transaction to update the last transaction time
-        await CreditTransaction.prisma().create(
-            data={
-                "userId": DEFAULT_USER_ID,
-                "amount": -700,  # Spent 700 to get to 400
-                "type": CreditTransactionType.USAGE,
-                "runningBalance": 400,
-                "isActive": True,
-                "createdAt": month2,
-            }
-        )
-
-        # Move to month 3
-        month3 = datetime.now(timezone.utc).replace(month=3, day=1)
-        user_credit.time_now = lambda: month3
-
-        # Should get refilled since balance (400) < refill value (1000)
-        month3_balance = await user_credit.get_credits(DEFAULT_USER_ID)
-        assert month3_balance == REFILL_VALUE  # Should be refilled to 1000
-
-        # Verify the refill transaction was created
-        refill_tx = await CreditTransaction.prisma().find_first(
-            where={
-                "userId": DEFAULT_USER_ID,
-                "type": CreditTransactionType.GRANT,
-                "transactionKey": {"contains": "MONTHLY-CREDIT-TOP-UP"},
-            },
-            order={"createdAt": "desc"},
-        )
-        assert refill_tx is not None, "Monthly refill transaction should be created"
-        assert refill_tx.amount == 600, "Refill should be 600 (1000 - 400)"
-    finally:
-        # Restore original time_now function
-        user_credit.time_now = original_time_now
+    # Month 2 balance is unaffected
+    user_credit.time_now = lambda: datetime.now(timezone.utc).replace(
+        month=month2, day=1
+    )
+    assert await user_credit.get_credits(DEFAULT_USER_ID) == month2credit


 @pytest.mark.asyncio(loop_scope="session")

@@ -1,361 +0,0 @@
"""
Test underflow protection for cumulative refunds and negative transactions.

This test ensures that when multiple large refunds are processed, the user balance
doesn't underflow below POSTGRES_INT_MIN, which could cause integer wraparound issues.
"""

import asyncio
from uuid import uuid4

import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance

from backend.data.credit import POSTGRES_INT_MIN, UserCredit
from backend.util.test import SpinTestServer


async def create_test_user(user_id: str) -> None:
    """Create a test user for underflow tests."""
    try:
        await User.prisma().create(
            data={
                "id": user_id,
                "email": f"test-{user_id}@example.com",
                "name": f"Test User {user_id[:8]}",
            }
        )
    except UniqueViolationError:
        # User already exists, continue
        pass

    await UserBalance.prisma().upsert(
        where={"userId": user_id},
        data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
    )


async def cleanup_test_user(user_id: str) -> None:
    """Clean up test user and their transactions."""
    try:
        await CreditTransaction.prisma().delete_many(where={"userId": user_id})
        await UserBalance.prisma().delete_many(where={"userId": user_id})
        await User.prisma().delete_many(where={"id": user_id})
    except Exception as e:
        # Log cleanup failures but don't fail the test
        print(f"Warning: Failed to cleanup test user {user_id}: {e}")


@pytest.mark.asyncio(loop_scope="session")
async def test_debug_underflow_step_by_step(server: SpinTestServer):
    """Debug underflow behavior step by step."""
    credit_system = UserCredit()
    user_id = f"debug-underflow-{uuid4()}"
    await create_test_user(user_id)

    try:
        print(f"POSTGRES_INT_MIN: {POSTGRES_INT_MIN}")

        # Test 1: Set up balance close to underflow threshold
        print("\n=== Test 1: Setting up balance close to underflow threshold ===")
        # First, manually set balance to a value very close to POSTGRES_INT_MIN
        # We'll set it to POSTGRES_INT_MIN + 100, then try to subtract 200
        # This should trigger underflow protection: (POSTGRES_INT_MIN + 100) + (-200) = POSTGRES_INT_MIN - 100
        initial_balance_target = POSTGRES_INT_MIN + 100

        # Use direct database update to set the balance close to underflow
        from prisma.models import UserBalance

        await UserBalance.prisma().upsert(
            where={"userId": user_id},
            data={
                "create": {"userId": user_id, "balance": initial_balance_target},
                "update": {"balance": initial_balance_target},
            },
        )

        current_balance = await credit_system.get_credits(user_id)
        print(f"Set balance to: {current_balance}")
        assert current_balance == initial_balance_target

        # Test 2: Apply amount that should cause underflow
        print("\n=== Test 2: Testing underflow protection ===")
        test_amount = (
            -200
        )  # This should cause underflow: (POSTGRES_INT_MIN + 100) + (-200) = POSTGRES_INT_MIN - 100
        expected_without_protection = current_balance + test_amount
        print(f"Current balance: {current_balance}")
        print(f"Test amount: {test_amount}")
        print(f"Without protection would be: {expected_without_protection}")
        print(f"Should be clamped to POSTGRES_INT_MIN: {POSTGRES_INT_MIN}")

        # Apply the amount that should trigger underflow protection
        balance_result, _ = await credit_system._add_transaction(
            user_id=user_id,
            amount=test_amount,
            transaction_type=CreditTransactionType.REFUND,
            fail_insufficient_credits=False,
        )
        print(f"Actual result: {balance_result}")

        # Check if underflow protection worked
        assert (
            balance_result == POSTGRES_INT_MIN
        ), f"Expected underflow protection to clamp balance to {POSTGRES_INT_MIN}, got {balance_result}"

        # Test 3: Edge case - exactly at POSTGRES_INT_MIN
        print("\n=== Test 3: Testing exact POSTGRES_INT_MIN boundary ===")
        # Set balance to exactly POSTGRES_INT_MIN
        await UserBalance.prisma().upsert(
            where={"userId": user_id},
            data={
                "create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
                "update": {"balance": POSTGRES_INT_MIN},
            },
        )

        edge_balance = await credit_system.get_credits(user_id)
        print(f"Balance set to exactly POSTGRES_INT_MIN: {edge_balance}")

        # Try to subtract 1 - should stay at POSTGRES_INT_MIN
        edge_result, _ = await credit_system._add_transaction(
            user_id=user_id,
            amount=-1,
            transaction_type=CreditTransactionType.REFUND,
            fail_insufficient_credits=False,
        )
        print(f"After subtracting 1: {edge_result}")

        assert (
            edge_result == POSTGRES_INT_MIN
        ), f"Expected balance to remain clamped at {POSTGRES_INT_MIN}, got {edge_result}"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_underflow_protection_large_refunds(server: SpinTestServer):
    """Test that large cumulative refunds don't cause integer underflow."""
    credit_system = UserCredit()
    user_id = f"underflow-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Set up balance close to underflow threshold to test the protection
        # Set balance to POSTGRES_INT_MIN + 1000, then try to subtract 2000
        # This should trigger underflow protection
        from prisma.models import UserBalance

        test_balance = POSTGRES_INT_MIN + 1000
        await UserBalance.prisma().upsert(
            where={"userId": user_id},
            data={
                "create": {"userId": user_id, "balance": test_balance},
                "update": {"balance": test_balance},
            },
        )

        current_balance = await credit_system.get_credits(user_id)
        assert current_balance == test_balance

        # Try to deduct amount that would cause underflow: test_balance + (-2000) = POSTGRES_INT_MIN - 1000
|
||||
underflow_amount = -2000
|
||||
expected_without_protection = (
|
||||
current_balance + underflow_amount
|
||||
) # Should be POSTGRES_INT_MIN - 1000
|
||||
|
||||
# Use _add_transaction directly with amount that would cause underflow
|
||||
final_balance, _ = await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=underflow_amount,
|
||||
transaction_type=CreditTransactionType.REFUND,
|
||||
fail_insufficient_credits=False, # Allow going negative for refunds
|
||||
)
|
||||
|
||||
# Balance should be clamped to POSTGRES_INT_MIN, not the calculated underflow value
|
||||
assert (
|
||||
final_balance == POSTGRES_INT_MIN
|
||||
), f"Balance should be clamped to {POSTGRES_INT_MIN}, got {final_balance}"
|
||||
assert (
|
||||
final_balance > expected_without_protection
|
||||
), f"Balance should be greater than underflow result {expected_without_protection}, got {final_balance}"
|
||||
|
||||
# Verify with get_credits too
|
||||
stored_balance = await credit_system.get_credits(user_id)
|
||||
assert (
|
||||
stored_balance == POSTGRES_INT_MIN
|
||||
), f"Stored balance should be {POSTGRES_INT_MIN}, got {stored_balance}"
|
||||
|
||||
# Verify transaction was created with the underflow-protected balance
|
||||
transactions = await CreditTransaction.prisma().find_many(
|
||||
where={"userId": user_id, "type": CreditTransactionType.REFUND},
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
assert len(transactions) > 0, "Refund transaction should be created"
|
||||
assert (
|
||||
transactions[0].runningBalance == POSTGRES_INT_MIN
|
||||
), f"Transaction should show clamped balance {POSTGRES_INT_MIN}, got {transactions[0].runningBalance}"
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServer):
|
||||
"""Test that multiple large refunds applied sequentially don't cause underflow."""
|
||||
credit_system = UserCredit()
|
||||
user_id = f"cumulative-underflow-test-{uuid4()}"
|
||||
await create_test_user(user_id)
|
||||
|
||||
try:
|
||||
# Set up balance close to underflow threshold
|
||||
from prisma.models import UserBalance
|
||||
|
||||
initial_balance = POSTGRES_INT_MIN + 500 # Close to minimum but with some room
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": initial_balance},
|
||||
"update": {"balance": initial_balance},
|
||||
},
|
||||
)
|
||||
|
||||
# Apply multiple refunds that would cumulatively underflow
|
||||
refund_amount = -300 # Each refund that would cause underflow when cumulative
|
||||
|
||||
# First refund: (POSTGRES_INT_MIN + 500) + (-300) = POSTGRES_INT_MIN + 200 (still above minimum)
|
||||
balance_1, _ = await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=refund_amount,
|
||||
transaction_type=CreditTransactionType.REFUND,
|
||||
fail_insufficient_credits=False,
|
||||
)
|
||||
|
||||
# Should be above minimum for first refund
|
||||
expected_balance_1 = (
|
||||
initial_balance + refund_amount
|
||||
) # Should be POSTGRES_INT_MIN + 200
|
||||
assert (
|
||||
balance_1 == expected_balance_1
|
||||
), f"First refund should result in {expected_balance_1}, got {balance_1}"
|
||||
assert (
|
||||
balance_1 >= POSTGRES_INT_MIN
|
||||
), f"First refund should not go below {POSTGRES_INT_MIN}, got {balance_1}"
|
||||
|
||||
# Second refund: (POSTGRES_INT_MIN + 200) + (-300) = POSTGRES_INT_MIN - 100 (would underflow)
|
||||
balance_2, _ = await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=refund_amount,
|
||||
transaction_type=CreditTransactionType.REFUND,
|
||||
fail_insufficient_credits=False,
|
||||
)
|
||||
|
||||
# Should be clamped to minimum due to underflow protection
|
||||
assert (
|
||||
balance_2 == POSTGRES_INT_MIN
|
||||
), f"Second refund should be clamped to {POSTGRES_INT_MIN}, got {balance_2}"
|
||||
|
||||
# Third refund: Should stay at minimum
|
||||
balance_3, _ = await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=refund_amount,
|
||||
transaction_type=CreditTransactionType.REFUND,
|
||||
fail_insufficient_credits=False,
|
||||
)
|
||||
|
||||
# Should still be at minimum
|
||||
assert (
|
||||
balance_3 == POSTGRES_INT_MIN
|
||||
), f"Third refund should stay at {POSTGRES_INT_MIN}, got {balance_3}"
|
||||
|
||||
# Final balance check
|
||||
final_balance = await credit_system.get_credits(user_id)
|
||||
assert (
|
||||
final_balance == POSTGRES_INT_MIN
|
||||
), f"Final balance should be {POSTGRES_INT_MIN}, got {final_balance}"
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):
|
||||
"""Test that concurrent large refunds don't cause race condition underflow."""
|
||||
credit_system = UserCredit()
|
||||
user_id = f"concurrent-underflow-test-{uuid4()}"
|
||||
await create_test_user(user_id)
|
||||
|
||||
try:
|
||||
# Set up balance close to underflow threshold
|
||||
from prisma.models import UserBalance
|
||||
|
||||
initial_balance = POSTGRES_INT_MIN + 1000 # Close to minimum
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": initial_balance},
|
||||
"update": {"balance": initial_balance},
|
||||
},
|
||||
)
|
||||
|
||||
async def large_refund(amount: int, label: str):
|
||||
try:
|
||||
return await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=-amount,
|
||||
transaction_type=CreditTransactionType.REFUND,
|
||||
fail_insufficient_credits=False,
|
||||
)
|
||||
except Exception as e:
|
||||
return f"FAILED-{label}: {e}"
|
||||
|
||||
# Run concurrent refunds that would cause underflow if not protected
|
||||
# Each refund of 500 would cause underflow: initial_balance + (-500) could go below POSTGRES_INT_MIN
|
||||
refund_amount = 500
|
||||
results = await asyncio.gather(
|
||||
large_refund(refund_amount, "A"),
|
||||
large_refund(refund_amount, "B"),
|
||||
large_refund(refund_amount, "C"),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
# Check all results are valid and no underflow occurred
|
||||
valid_results = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, tuple):
|
||||
balance, _ = result
|
||||
assert (
|
||||
balance >= POSTGRES_INT_MIN
|
||||
), f"Result {i} balance {balance} underflowed below {POSTGRES_INT_MIN}"
|
||||
valid_results.append(balance)
|
||||
elif isinstance(result, str) and "FAILED" in result:
|
||||
# Some operations might fail due to validation, that's okay
|
||||
pass
|
||||
else:
|
||||
# Unexpected exception
|
||||
assert not isinstance(
|
||||
result, Exception
|
||||
), f"Unexpected exception in result {i}: {result}"
|
||||
|
||||
# At least one operation should succeed
|
||||
assert (
|
||||
len(valid_results) > 0
|
||||
), f"At least one refund should succeed, got results: {results}"
|
||||
|
||||
# All successful results should be >= POSTGRES_INT_MIN
|
||||
for balance in valid_results:
|
||||
assert (
|
||||
balance >= POSTGRES_INT_MIN
|
||||
), f"Balance {balance} should not be below {POSTGRES_INT_MIN}"
|
||||
|
||||
# Final balance should be valid and at or above POSTGRES_INT_MIN
|
||||
final_balance = await credit_system.get_credits(user_id)
|
||||
assert (
|
||||
final_balance >= POSTGRES_INT_MIN
|
||||
), f"Final balance {final_balance} should not underflow below {POSTGRES_INT_MIN}"
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
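The protection exercised by the tests above amounts to clamping the computed running balance into PostgreSQL's 32-bit INT range. A minimal standalone sketch of that clamp (an assumption about the implementation; `apply_transaction` is an illustrative name, not the actual `_add_transaction` method):

```python
# PostgreSQL 32-bit INT bounds
POSTGRES_INT_MIN = -2_147_483_648
POSTGRES_INT_MAX = 2_147_483_647


def apply_transaction(balance: int, amount: int) -> int:
    """Sketch: clamp the new running balance into the INT range
    instead of letting it overflow/underflow the column type."""
    return max(POSTGRES_INT_MIN, min(POSTGRES_INT_MAX, balance + amount))


# Mirrors Test 2 above: a -200 refund near the floor clamps to the minimum
result = apply_transaction(POSTGRES_INT_MIN + 100, -200)
```

With this rule, repeated refunds at the floor stay pinned at `POSTGRES_INT_MIN`, which is exactly what the cumulative and concurrent tests assert.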
@@ -1,217 +0,0 @@
"""
Integration test to verify complete migration from User.balance to UserBalance table.

This test ensures that:
1. No User.balance queries exist in the system
2. All balance operations go through UserBalance table
3. User and UserBalance stay synchronized properly
"""

import asyncio
from datetime import datetime

import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance

from backend.data.credit import UsageTransactionMetadata, UserCredit
from backend.util.json import SafeJson
from backend.util.test import SpinTestServer


async def create_test_user(user_id: str) -> None:
    """Create a test user for migration tests."""
    try:
        await User.prisma().create(
            data={
                "id": user_id,
                "email": f"test-{user_id}@example.com",
                "name": f"Test User {user_id[:8]}",
            }
        )
    except UniqueViolationError:
        # User already exists, continue
        pass


async def cleanup_test_user(user_id: str) -> None:
    """Clean up test user and their data."""
    try:
        await CreditTransaction.prisma().delete_many(where={"userId": user_id})
        await UserBalance.prisma().delete_many(where={"userId": user_id})
        await User.prisma().delete_many(where={"id": user_id})
    except Exception as e:
        # Log cleanup failures but don't fail the test
        print(f"Warning: Failed to cleanup test user {user_id}: {e}")


@pytest.mark.asyncio(loop_scope="session")
async def test_user_balance_migration_complete(server: SpinTestServer):
    """Test that User table balance is never used and UserBalance is source of truth."""
    credit_system = UserCredit()
    user_id = f"migration-test-{datetime.now().timestamp()}"
    await create_test_user(user_id)

    try:
        # 1. Verify User table does NOT have balance set initially
        user = await User.prisma().find_unique(where={"id": user_id})
        assert user is not None
        # User.balance should not exist or should be None/0 if it exists
        user_balance_attr = getattr(user, "balance", None)
        if user_balance_attr is not None:
            assert (
                user_balance_attr == 0 or user_balance_attr is None
            ), f"User.balance should be 0 or None, got {user_balance_attr}"

        # 2. Perform various credit operations using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=1000,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "migration_test"}),
        )
        balance1 = await credit_system.get_credits(user_id)
        assert balance1 == 1000

        await credit_system.spend_credits(
            user_id,
            300,
            UsageTransactionMetadata(
                graph_exec_id="test", reason="Migration test spend"
            ),
        )
        balance2 = await credit_system.get_credits(user_id)
        assert balance2 == 700

        # 3. Verify UserBalance table has correct values
        user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
        assert user_balance is not None
        assert (
            user_balance.balance == 700
        ), f"UserBalance should be 700, got {user_balance.balance}"

        # 4. CRITICAL: Verify User.balance is NEVER updated during operations
        user_after = await User.prisma().find_unique(where={"id": user_id})
        assert user_after is not None
        user_balance_after = getattr(user_after, "balance", None)
        if user_balance_after is not None:
            # If User.balance exists, it should still be 0 (never updated)
            assert (
                user_balance_after == 0 or user_balance_after is None
            ), f"User.balance should remain 0/None after operations, got {user_balance_after}. This indicates User.balance is still being used!"

        # 5. Verify get_credits always returns UserBalance value, not User.balance
        final_balance = await credit_system.get_credits(user_id)
        assert (
            final_balance == user_balance.balance
        ), f"get_credits should return UserBalance value {user_balance.balance}, got {final_balance}"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_detect_stale_user_balance_queries(server: SpinTestServer):
    """Test to detect if any operations are still using User.balance instead of UserBalance."""
    credit_system = UserCredit()
    user_id = f"stale-query-test-{datetime.now().timestamp()}"
    await create_test_user(user_id)

    try:
        # Create UserBalance with specific value
        await UserBalance.prisma().create(
            data={"userId": user_id, "balance": 5000}  # $50
        )

        # Verify that get_credits returns UserBalance value (5000), not any stale User.balance value
        balance = await credit_system.get_credits(user_id)
        assert (
            balance == 5000
        ), f"Expected get_credits to return 5000 from UserBalance, got {balance}"

        # Verify all operations use UserBalance using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=1000,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "final_verification"}),
        )
        final_balance = await credit_system.get_credits(user_id)
        assert final_balance == 6000, f"Expected 6000, got {final_balance}"

        # Verify UserBalance table has the correct value
        user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
        assert user_balance is not None
        assert (
            user_balance.balance == 6000
        ), f"UserBalance should be 6000, got {user_balance.balance}"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer):
    """Test that concurrent operations all use UserBalance locking, not User.balance."""
    credit_system = UserCredit()
    user_id = f"concurrent-userbalance-test-{datetime.now().timestamp()}"
    await create_test_user(user_id)

    try:
        # Set initial balance in UserBalance
        await UserBalance.prisma().create(data={"userId": user_id, "balance": 1000})

        # Run concurrent operations to ensure they all use UserBalance atomic operations
        async def concurrent_spend(amount: int, label: str):
            try:
                await credit_system.spend_credits(
                    user_id,
                    amount,
                    UsageTransactionMetadata(
                        graph_exec_id=f"concurrent-{label}",
                        reason=f"Concurrent test {label}",
                    ),
                )
                return f"{label}-SUCCESS"
            except Exception as e:
                return f"{label}-FAILED: {e}"

        # Run concurrent operations
        results = await asyncio.gather(
            concurrent_spend(100, "A"),
            concurrent_spend(200, "B"),
            concurrent_spend(300, "C"),
            return_exceptions=True,
        )

        # All should succeed (1000 >= 100+200+300)
        successful = [r for r in results if "SUCCESS" in str(r)]
        assert len(successful) == 3, f"All operations should succeed, got {results}"

        # Final balance should be 1000 - 600 = 400
        final_balance = await credit_system.get_credits(user_id)
        assert final_balance == 400, f"Expected final balance 400, got {final_balance}"

        # Verify UserBalance has correct value
        user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
        assert user_balance is not None
        assert (
            user_balance.balance == 400
        ), f"UserBalance should be 400, got {user_balance.balance}"

        # Critical: If User.balance exists and was used, it might have wrong value
        try:
            user = await User.prisma().find_unique(where={"id": user_id})
            user_balance_attr = getattr(user, "balance", None)
            if user_balance_attr is not None:
                # If User.balance exists, it should NOT be used for operations
                # The fact that our final balance is correct from UserBalance proves the system is working
                print(
                    f"✅ User.balance exists ({user_balance_attr}) but UserBalance ({user_balance.balance}) is being used correctly"
                )
        except Exception:
            print("✅ User.balance column doesn't exist - migration is complete")

    finally:
        await cleanup_test_user(user_id)
@@ -98,6 +98,42 @@ async def transaction(timeout: int = TRANSACTION_TIMEOUT):
        yield tx


@asynccontextmanager
async def locked_transaction(key: str, timeout: int = TRANSACTION_TIMEOUT):
    """
    Create a transaction and take a per-key advisory *transaction* lock.

    - Uses a 64-bit lock id via hashtextextended(key, 0) to avoid 32-bit collisions.
    - Bound by lock_timeout and statement_timeout so it won't block indefinitely.
    - Lock is held for the duration of the transaction and auto-released on commit/rollback.

    Args:
        key: String lock key (e.g., "usr_trx_<uuid>").
        timeout: Transaction/lock/statement timeout in milliseconds.
    """
    async with transaction(timeout=timeout) as tx:
        # Ensure we don't wait longer than desired
        # Note: SET LOCAL doesn't support parameterized queries, must use string interpolation
        await tx.execute_raw(f"SET LOCAL statement_timeout = '{int(timeout)}ms'")  # type: ignore[arg-type]
        await tx.execute_raw(f"SET LOCAL lock_timeout = '{int(timeout)}ms'")  # type: ignore[arg-type]

        # Block until acquired or lock_timeout hits
        try:
            await tx.execute_raw(
                "SELECT pg_advisory_xact_lock(hashtextextended($1, 0))",
                key,
            )
        except Exception as e:
            # Normalize PG's lock timeout error to TimeoutError for callers
            if "lock timeout" in str(e).lower():
                raise TimeoutError(
                    f"Could not acquire lock for key={key!r} within {timeout}ms"
                ) from e
            raise

        yield tx


def get_database_schema() -> str:
    """Extract database schema from DATABASE_URL."""
    parsed_url = urlparse(DATABASE_URL)
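The hunk above computes the 64-bit advisory lock id server-side with `hashtextextended(key, 0)`. A minimal client-side sketch of the same idea — mapping an arbitrary string key into the signed 64-bit range that `pg_advisory_xact_lock(bigint)` accepts — is shown below; the hash function here (SHA-256 prefix) is an illustrative assumption and does NOT reproduce Postgres's `hashtextextended` values:

```python
import hashlib


def advisory_lock_id(key: str) -> int:
    """Sketch: derive a deterministic signed 64-bit lock id from a string key.
    (Postgres computes this server-side via hashtextextended; this is not
    the same hash, just the same signed-BIGINT wrapping idea.)"""
    digest = hashlib.sha256(key.encode("utf-8")).digest()
    unsigned = int.from_bytes(digest[:8], "big")
    # Wrap into signed 64-bit range, matching BIGINT semantics
    return unsigned - (1 << 64) if unsigned >= (1 << 63) else unsigned


lock_id = advisory_lock_id("usr_trx_123")
```

Using 64-bit ids rather than the 32-bit `pg_advisory_xact_lock(int, int)` form is what the docstring means by avoiding 32-bit collisions: with many per-user keys, a 32-bit space makes accidental lock sharing far more likely.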
@@ -38,8 +38,8 @@ from prisma.types import (
from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
from pydantic.fields import Field

from backend.server.v2.store.exceptions import DatabaseError
from backend.util import type as type_utils
from backend.util.exceptions import DatabaseError
from backend.util.json import SafeJson
from backend.util.models import Pagination
from backend.util.retry import func_retry
@@ -478,48 +478,6 @@ async def get_graph_executions(
    return [GraphExecutionMeta.from_db(execution) for execution in executions]


async def get_graph_executions_count(
    user_id: Optional[str] = None,
    graph_id: Optional[str] = None,
    statuses: Optional[list[ExecutionStatus]] = None,
    created_time_gte: Optional[datetime] = None,
    created_time_lte: Optional[datetime] = None,
) -> int:
    """
    Get count of graph executions with optional filters.

    Args:
        user_id: Optional user ID to filter by
        graph_id: Optional graph ID to filter by
        statuses: Optional list of execution statuses to filter by
        created_time_gte: Optional minimum creation time
        created_time_lte: Optional maximum creation time

    Returns:
        Count of matching graph executions
    """
    where_filter: AgentGraphExecutionWhereInput = {
        "isDeleted": False,
    }

    if user_id:
        where_filter["userId"] = user_id

    if graph_id:
        where_filter["agentGraphId"] = graph_id

    if created_time_gte or created_time_lte:
        where_filter["createdAt"] = {
            "gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
            "lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
        }
    if statuses:
        where_filter["OR"] = [{"executionStatus": status} for status in statuses]

    count = await AgentGraphExecution.prisma().count(where=where_filter)
    return count


class GraphExecutionsPaginated(BaseModel):
    """Response schema for paginated graph executions."""
@@ -7,7 +7,7 @@ from prisma.enums import AgentExecutionStatus
from backend.data.execution import get_graph_executions
from backend.data.graph import get_graph_metadata
from backend.data.model import UserExecutionSummaryStats
from backend.util.exceptions import DatabaseError
from backend.server.v2.store.exceptions import DatabaseError
from backend.util.logging import TruncatedLogger

logger = TruncatedLogger(logging.getLogger(__name__), prefix="[SummaryData]")
@@ -129,20 +129,17 @@ class NodeModel(Node):
        Returns a copy of the node model, stripped of any non-transferable properties
        """
        stripped_node = self.model_copy(deep=True)

        # Remove credentials and other (possible) secrets from node input
        # Remove credentials from node input
        if stripped_node.input_default:
            stripped_node.input_default = NodeModel._filter_secrets_from_node_input(
                stripped_node.input_default, self.block.input_schema.jsonschema()
            )

        # Remove default secret value from secret input nodes
        if (
            stripped_node.block.block_type == BlockType.INPUT
            and stripped_node.input_default.get("secret", False) is True
            and "value" in stripped_node.input_default
        ):
            del stripped_node.input_default["value"]
            stripped_node.input_default["value"] = ""

        # Remove webhook info
        stripped_node.webhook_id = None
@@ -159,10 +156,8 @@ class NodeModel(Node):
        result = {}
        for key, value in input_data.items():
            field_schema: dict | None = field_schemas.get(key)
            if (field_schema and field_schema.get("secret", False)) or (
                any(sensitive_key in key.lower() for sensitive_key in sensitive_keys)
                # Prevent removing `secret` flag on input nodes
                and type(value) is not bool
            if (field_schema and field_schema.get("secret", False)) or any(
                sensitive_key in key.lower() for sensitive_key in sensitive_keys
            ):
                # This is a secret value -> filter this key-value pair out
                continue
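The filtering rule in the hunk above drops a key either when its schema marks it `secret` or when the key name looks sensitive (unless the value is a boolean flag like the `secret` marker itself). A standalone sketch of that rule, with illustrative names (`filter_secrets`, `SENSITIVE_KEYS`) rather than the actual `NodeModel._filter_secrets_from_node_input` helper:

```python
# Illustrative stand-in for the sensitive-key list used by the real helper
SENSITIVE_KEYS = ("api_key", "password", "token", "credentials")


def filter_secrets(input_data: dict, field_schemas: dict) -> dict:
    """Sketch of the secret-filtering rule from the hunk above."""
    result = {}
    for key, value in input_data.items():
        field_schema = field_schemas.get(key)
        if (field_schema and field_schema.get("secret", False)) or (
            any(s in key.lower() for s in SENSITIVE_KEYS)
            # Keep boolean flags (e.g. the `secret` marker on input nodes)
            and type(value) is not bool
        ):
            # Secret value -> drop this key-value pair
            continue
        result[key] = value
    return result


filtered = filter_secrets({"input": "ok", "api_key": "sk-123", "secret": True}, {})
```

The boolean guard is the behavioral difference the diff is debating: without it, a key named `secret` holding the flag `True` would itself be stripped from input nodes.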
@@ -201,56 +201,25 @@ async def test_get_input_schema(server: SpinTestServer, snapshot: Snapshot):
@pytest.mark.asyncio(loop_scope="session")
async def test_clean_graph(server: SpinTestServer):
    """
    Test the stripped_for_export function that:
    1. Removes sensitive/secret fields from node inputs
    2. Removes webhook information
    3. Preserves non-sensitive data including input block values
    Test the clean_graph function that:
    1. Clears input block values
    2. Removes credentials from nodes
    """
    # Create a graph with input blocks containing both sensitive and normal data
    # Create a graph with input blocks and credentials
    graph = Graph(
        id="test_clean_graph",
        name="Test Clean Graph",
        description="Test graph cleaning",
        nodes=[
            Node(
                id="input_node",
                block_id=AgentInputBlock().id,
                input_default={
                    "_test_id": "input_node",
                    "name": "test_input",
                    "value": "test value",  # This should be preserved
                    "value": "test value",
                    "description": "Test input description",
                },
            ),
            Node(
                block_id=AgentInputBlock().id,
                input_default={
                    "_test_id": "input_node_secret",
                    "name": "secret_input",
                    "value": "another value",
                    "secret": True,  # This makes the input secret
                },
            ),
            Node(
                block_id=StoreValueBlock().id,
                input_default={
                    "_test_id": "node_with_secrets",
                    "input": "normal_value",
                    "control_test_input": "should be preserved",
                    "api_key": "secret_api_key_123",  # Should be filtered
                    "password": "secret_password_456",  # Should be filtered
                    "token": "secret_token_789",  # Should be filtered
                    "credentials": {  # Should be filtered
                        "id": "fake-github-credentials-id",
                        "provider": "github",
                        "type": "api_key",
                    },
                    "anthropic_credentials": {  # Should be filtered
                        "id": "fake-anthropic-credentials-id",
                        "provider": "anthropic",
                        "type": "api_key",
                    },
                },
            ),
        ],
        links=[],
    )
@@ -262,54 +231,15 @@ async def test_clean_graph(server: SpinTestServer):
    )

    # Clean the graph
    cleaned_graph = await server.agent_server.test_get_graph(
    created_graph = await server.agent_server.test_get_graph(
        created_graph.id, created_graph.version, DEFAULT_USER_ID, for_export=True
    )

    # Verify sensitive fields are removed but normal fields are preserved
    # Verify input block value is cleared
    input_node = next(
        n for n in cleaned_graph.nodes if n.input_default["_test_id"] == "input_node"
        n for n in created_graph.nodes if n.block_id == AgentInputBlock().id
    )

    # Non-sensitive fields should be preserved
    assert input_node.input_default["name"] == "test_input"
    assert input_node.input_default["value"] == "test value"  # Should be preserved now
    assert input_node.input_default["description"] == "Test input description"

    # Sensitive fields should be filtered out
    assert "api_key" not in input_node.input_default
    assert "password" not in input_node.input_default

    # Verify secret input node preserves non-sensitive fields but removes secret value
    secret_node = next(
        n
        for n in cleaned_graph.nodes
        if n.input_default["_test_id"] == "input_node_secret"
    )
    assert secret_node.input_default["name"] == "secret_input"
    assert "value" not in secret_node.input_default  # Secret default should be removed
    assert secret_node.input_default["secret"] is True

    # Verify sensitive fields are filtered from nodes with secrets
    secrets_node = next(
        n
        for n in cleaned_graph.nodes
        if n.input_default["_test_id"] == "node_with_secrets"
    )
    # Normal fields should be preserved
    assert secrets_node.input_default["input"] == "normal_value"
    assert secrets_node.input_default["control_test_input"] == "should be preserved"
    # Sensitive fields should be filtered out
    assert "api_key" not in secrets_node.input_default
    assert "password" not in secrets_node.input_default
    assert "token" not in secrets_node.input_default
    assert "credentials" not in secrets_node.input_default
    assert "anthropic_credentials" not in secrets_node.input_default

    # Verify webhook info is removed (if any nodes had it)
    for node in cleaned_graph.nodes:
        assert node.webhook_id is None
        assert node.webhook is None
    assert input_node.input_default["value"] == ""


@pytest.mark.asyncio(loop_scope="session")
@@ -347,9 +347,6 @@ class APIKeyCredentials(_BaseCredentials):
    """Unix timestamp (seconds) indicating when the API key expires (if at all)"""

    def auth_header(self) -> str:
        # Linear API keys should not have Bearer prefix
        if self.provider == "linear":
            return self.api_key.get_secret_value()
        return f"Bearer {self.api_key.get_secret_value()}"
@@ -15,7 +15,7 @@ from prisma.types import (
# from backend.notifications.models import NotificationEvent
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator

from backend.util.exceptions import DatabaseError
from backend.server.v2.store.exceptions import DatabaseError
from backend.util.json import SafeJson
from backend.util.logging import TruncatedLogger
@@ -235,7 +235,6 @@ class BaseEventModel(BaseModel):


class NotificationEventModel(BaseEventModel, Generic[NotificationDataType_co]):
    id: Optional[str] = None  # None when creating, populated when reading from DB
    data: NotificationDataType_co

    @property
@@ -379,7 +378,6 @@ class NotificationPreference(BaseModel):


class UserNotificationEventDTO(BaseModel):
    id: str  # Added to track notifications for removal
    type: NotificationType
    data: dict
    created_at: datetime
@@ -388,7 +386,6 @@ class UserNotificationEventDTO(BaseModel):
    @staticmethod
    def from_db(model: NotificationEvent) -> "UserNotificationEventDTO":
        return UserNotificationEventDTO(
            id=model.id,
            type=model.type,
            data=dict(model.data),
            created_at=model.createdAt,
@@ -544,79 +541,6 @@ async def empty_user_notification_batch(
    ) from e


async def clear_all_user_notification_batches(user_id: str) -> None:
    """Clear ALL notification batches for a user across all types.

    Used when user's email is bounced/inactive and we should stop
    trying to send them ANY emails.
    """
    try:
        # Delete all notification events for this user
        await NotificationEvent.prisma().delete_many(
            where={"UserNotificationBatch": {"is": {"userId": user_id}}}
        )

        # Delete all batches for this user
        await UserNotificationBatch.prisma().delete_many(where={"userId": user_id})

        logger.info(f"Cleared all notification batches for user {user_id}")
    except Exception as e:
        raise DatabaseError(
            f"Failed to clear all notification batches for user {user_id}: {e}"
        ) from e


async def remove_notifications_from_batch(
    user_id: str, notification_type: NotificationType, notification_ids: list[str]
) -> None:
    """Remove specific notifications from a user's batch by their IDs.

    This is used after successful sending to remove only the
    sent notifications, preventing duplicates on retry.
    """
    if not notification_ids:
        return

    try:
        # Delete the specific notification events
        deleted_count = await NotificationEvent.prisma().delete_many(
            where={
                "id": {"in": notification_ids},
                "UserNotificationBatch": {
                    "is": {"userId": user_id, "type": notification_type}
                },
            }
        )

        logger.info(
            f"Removed {deleted_count} notifications from batch for user {user_id}"
        )

        # Check if batch is now empty and delete it if so
        remaining = await NotificationEvent.prisma().count(
            where={
                "UserNotificationBatch": {
                    "is": {"userId": user_id, "type": notification_type}
                }
            }
        )

        if remaining == 0:
            await UserNotificationBatch.prisma().delete_many(
                where=UserNotificationBatchWhereInput(
                    userId=user_id,
                    type=notification_type,
                )
            )
            logger.info(
                f"Deleted empty batch for user {user_id} and type {notification_type}"
            )
    except Exception as e:
        raise DatabaseError(
            f"Failed to remove notifications from batch for user {user_id} and type {notification_type}: {e}"
        ) from e


async def get_user_notification_batch(
    user_id: str,
|
||||
notification_type: NotificationType,
|
||||
|
||||
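The docstring above describes a remove-then-prune pattern: after a successful send, delete only the sent notifications, then drop the batch entirely once it is empty, so a retry never re-sends what already went out. A minimal, self-contained sketch of that pattern, using an in-memory store instead of Prisma (all names here are illustrative, not the project's API):

```python
# In-memory stand-in for the notification batch tables.
batches: dict[tuple[str, str], list[str]] = {}  # (user_id, type) -> notification ids


def add_to_batch(user_id: str, ntype: str, notification_id: str) -> None:
    batches.setdefault((user_id, ntype), []).append(notification_id)


def remove_from_batch(user_id: str, ntype: str, sent_ids: list[str]) -> int:
    """Remove the sent notifications; prune the batch if nothing remains."""
    key = (user_id, ntype)
    if key not in batches or not sent_ids:
        return 0
    before = len(batches[key])
    batches[key] = [i for i in batches[key] if i not in set(sent_ids)]
    removed = before - len(batches[key])
    if not batches[key]:  # batch is now empty -> delete it
        del batches[key]
    return removed
```

On a retry after a partial send, only the unsent IDs remain in the batch, which is what prevents duplicate emails.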
@@ -4,15 +4,16 @@ from typing import Any, Optional

import prisma
import pydantic
from autogpt_libs.utils.cache import cached
from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput

from backend.data.block import get_blocks
from backend.data.credit import get_user_credit_model
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.server.v2.store.model import StoreAgentDetails
from backend.util.cache import cached
from backend.util.json import SafeJson

# Mapping from user reason id to categories to search for when choosing agent to show
@@ -26,6 +27,8 @@ REASON_MAPPING: dict[str, list[str]] = {
POINTS_AGENT_COUNT = 50  # Number of agents to calculate points for
MIN_AGENT_COUNT = 2  # Minimum number of marketplace agents to enable onboarding

user_credit = get_user_credit_model()


class UserOnboardingUpdate(pydantic.BaseModel):
    completedSteps: Optional[list[OnboardingStep]] = None
@@ -145,8 +148,7 @@ async def reward_user(user_id: str, step: OnboardingStep):
        return

    onboarding.rewardedFor.append(step)
    user_credit_model = await get_user_credit_model(user_id)
    await user_credit_model.onboarding_reward(user_id, reward, step)
    await user_credit.onboarding_reward(user_id, reward, step)
    await UserOnboarding.prisma().update(
        where={"userId": user_id},
        data={
@@ -276,14 +278,8 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
        for word in user_onboarding.integrations
    ]

    where_clause["is_available"] = True

    # Try to take only agents that are available and allowed for onboarding
    storeAgents = await prisma.models.StoreAgent.prisma().find_many(
        where={
            "is_available": True,
            "useForOnboarding": True,
        },
        where=prisma.types.StoreAgentWhereInput(**where_clause),
        order=[
            {"featured": "desc"},
            {"runs": "desc"},
@@ -292,16 +288,59 @@
        take=100,
    )

    # If not enough agents found, relax the useForOnboarding filter
    agentListings = await prisma.models.StoreListingVersion.prisma().find_many(
        where={
            "id": {"in": [agent.storeListingVersionId for agent in storeAgents]},
        },
        include={"AgentGraph": True},
    )

    for listing in agentListings:
        agent = listing.AgentGraph
        if agent is None:
            continue
        graph = GraphModel.from_db(agent)
        # Remove agents with empty input schema
        if not graph.input_schema:
            storeAgents = [
                a for a in storeAgents if a.storeListingVersionId != listing.id
            ]
            continue

        # Remove agents with empty credentials
        # Get nodes from this agent that have credentials
        nodes = await prisma.models.AgentNode.prisma().find_many(
            where={
                "agentGraphId": agent.id,
                "agentBlockId": {"in": list(CREDENTIALS_FIELDS.keys())},
            },
        )
        for node in nodes:
            block_id = node.agentBlockId
            field_name = CREDENTIALS_FIELDS[block_id]
            # If there are no credentials or they are empty, remove the agent
            # FIXME ignores default values
            if (
                field_name not in node.constantInput
                or node.constantInput[field_name] is None
            ):
                storeAgents = [
                    a for a in storeAgents if a.storeListingVersionId != listing.id
                ]
                break

    # If there are less than 2 agents, add more agents to the list
    if len(storeAgents) < 2:
        storeAgents = await prisma.models.StoreAgent.prisma().find_many(
            where=prisma.types.StoreAgentWhereInput(**where_clause),
        storeAgents += await prisma.models.StoreAgent.prisma().find_many(
            where={
                "listing_id": {"not_in": [agent.listing_id for agent in storeAgents]},
            },
            order=[
                {"featured": "desc"},
                {"runs": "desc"},
                {"rating": "desc"},
            ],
            take=100,
            take=2 - len(storeAgents),
        )

    # Calculate points for the first X agents and choose the top 2
@@ -1,5 +0,0 @@
import prisma.models


class StoreAgentWithRank(prisma.models.StoreAgent):
    rank: float
@@ -1,29 +1,24 @@
import logging
import os

from dotenv import load_dotenv
from redis import Redis
from redis.asyncio import Redis as AsyncRedis

from backend.util.cache import cached, thread_cached
from backend.util.retry import conn_retry
from backend.util.settings import Settings

load_dotenv()

HOST = os.getenv("REDIS_HOST", "localhost")
PORT = int(os.getenv("REDIS_PORT", "6379"))
PASSWORD = os.getenv("REDIS_PASSWORD", None)
settings = Settings()

logger = logging.getLogger(__name__)


@conn_retry("Redis", "Acquiring connection")
def connect() -> Redis:
def connect(decode_responses: bool = True) -> Redis:
    c = Redis(
        host=HOST,
        port=PORT,
        password=PASSWORD,
        decode_responses=True,
        host=settings.config.redis_host,
        port=settings.config.redis_port,
        password=settings.config.redis_password or None,
        decode_responses=decode_responses,
    )
    c.ping()
    return c
@@ -42,9 +37,9 @@ def get_redis() -> Redis:
@conn_retry("AsyncRedis", "Acquiring connection")
async def connect_async() -> AsyncRedis:
    c = AsyncRedis(
        host=HOST,
        port=PORT,
        password=PASSWORD,
        host=settings.config.redis_host,
        port=settings.config.redis_port,
        password=settings.config.redis_password or None,
        decode_responses=True,
    )
    await c.ping()

@@ -15,9 +15,9 @@ from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
from backend.data.db import prisma
from backend.data.model import User, UserIntegrations, UserMetadata
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.server.v2.store.exceptions import DatabaseError
from backend.util.cache import cached
from backend.util.encryption import JSONCryptor
from backend.util.exceptions import DatabaseError
from backend.util.json import SafeJson
from backend.util.settings import Settings

@@ -354,36 +354,6 @@ async def set_user_email_verification(user_id: str, verified: bool) -> None:
        ) from e


async def disable_all_user_notifications(user_id: str) -> None:
    """Disable all notification preferences for a user.

    Used when user's email bounces/is inactive to prevent any future notifications.
    """
    try:
        await PrismaUser.prisma().update(
            where={"id": user_id},
            data={
                "notifyOnAgentRun": False,
                "notifyOnZeroBalance": False,
                "notifyOnLowBalance": False,
                "notifyOnBlockExecutionFailed": False,
                "notifyOnContinuousAgentError": False,
                "notifyOnDailySummary": False,
                "notifyOnWeeklySummary": False,
                "notifyOnMonthlySummary": False,
                "notifyOnAgentApproved": False,
                "notifyOnAgentRejected": False,
            },
        )
        # Invalidate cache for this user
        get_user_by_id.cache_delete(user_id)
        logger.info(f"Disabled all notification preferences for user {user_id}")
    except Exception as e:
        raise DatabaseError(
            f"Failed to disable notifications for user {user_id}: {e}"
        ) from e


async def get_user_email_verification(user_id: str) -> bool:
    """Get the email verification status for a user."""
    try:

@@ -22,13 +22,15 @@ logger = logging.getLogger(__name__)
@pytest.fixture
def redis_client():
    """Get Redis client for testing using same config as backend."""
    from backend.data.redis_client import HOST, PASSWORD, PORT
    from backend.util.settings import Settings

    settings = Settings()

    # Use same config as backend but without decode_responses since ClusterLock needs raw bytes
    client = redis.Redis(
        host=HOST,
        port=PORT,
        password=PASSWORD,
        host=settings.config.redis_host,
        port=settings.config.redis_port,
        password=settings.config.redis_password or None,
        decode_responses=False,  # ClusterLock needs raw bytes for ownership verification
    )

@@ -1,6 +1,5 @@
import logging
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cast
from typing import Callable, Concatenate, ParamSpec, TypeVar, cast

from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
@@ -10,7 +9,6 @@ from backend.data.execution import (
    get_execution_kv_data,
    get_graph_execution_meta,
    get_graph_executions,
    get_graph_executions_count,
    get_latest_node_execution,
    get_node_execution,
    get_node_executions,
@@ -30,17 +28,14 @@ from backend.data.graph import (
    get_node,
)
from backend.data.notifications import (
    clear_all_user_notification_batches,
    create_or_add_to_user_notification_batch,
    empty_user_notification_batch,
    get_all_batches_by_type,
    get_user_notification_batch,
    get_user_notification_oldest_message_in_batch,
    remove_notifications_from_batch,
)
from backend.data.user import (
    get_active_user_ids_in_timerange,
    get_user_by_id,
    get_user_email_by_id,
    get_user_email_verification,
    get_user_integrations,
@@ -58,10 +53,8 @@ from backend.util.service import (
)
from backend.util.settings import Config

if TYPE_CHECKING:
    from fastapi import FastAPI

config = Config()
_user_credit_model = get_user_credit_model()
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
@@ -70,27 +63,24 @@ R = TypeVar("R")
async def _spend_credits(
    user_id: str, cost: int, metadata: UsageTransactionMetadata
) -> int:
    user_credit_model = await get_user_credit_model(user_id)
    return await user_credit_model.spend_credits(user_id, cost, metadata)
    return await _user_credit_model.spend_credits(user_id, cost, metadata)


async def _get_credits(user_id: str) -> int:
    user_credit_model = await get_user_credit_model(user_id)
    return await user_credit_model.get_credits(user_id)
    return await _user_credit_model.get_credits(user_id)


class DatabaseManager(AppService):
    @asynccontextmanager
    async def lifespan(self, app: "FastAPI"):
        async with super().lifespan(app):
            logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
            await db.connect()

            logger.info(f"[{self.service_name}] ✅ Ready")
            yield
    def run_service(self) -> None:
        logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
        self.run_and_wait(db.connect())
        super().run_service()

            logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
            await db.disconnect()
    def cleanup(self):
        super().cleanup()
        logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
        self.run_and_wait(db.disconnect())

    async def health_check(self) -> str:
        if not db.is_connected():
@@ -121,7 +111,6 @@ class DatabaseManager(AppService):

    # Executions
    get_graph_executions = _(get_graph_executions)
    get_graph_executions_count = _(get_graph_executions_count)
    get_graph_execution_meta = _(get_graph_execution_meta)
    create_graph_execution = _(create_graph_execution)
    get_node_execution = _(get_node_execution)
@@ -153,18 +142,15 @@ class DatabaseManager(AppService):

    # User Comms - async
    get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
    get_user_by_id = _(get_user_by_id)
    get_user_email_by_id = _(get_user_email_by_id)
    get_user_email_verification = _(get_user_email_verification)
    get_user_notification_preference = _(get_user_notification_preference)

    # Notifications - async
    clear_all_user_notification_batches = _(clear_all_user_notification_batches)
    create_or_add_to_user_notification_batch = _(
        create_or_add_to_user_notification_batch
    )
    empty_user_notification_batch = _(empty_user_notification_batch)
    remove_notifications_from_batch = _(remove_notifications_from_batch)
    get_all_batches_by_type = _(get_all_batches_by_type)
    get_user_notification_batch = _(get_user_notification_batch)
    get_user_notification_oldest_message_in_batch = _(
@@ -193,7 +179,6 @@ class DatabaseManagerClient(AppServiceClient):

    # Executions
    get_graph_executions = _(d.get_graph_executions)
    get_graph_executions_count = _(d.get_graph_executions_count)
    get_graph_execution_meta = _(d.get_graph_execution_meta)
    get_node_executions = _(d.get_node_executions)
    update_node_execution_status = _(d.update_node_execution_status)
@@ -239,7 +224,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
    get_node = d.get_node
    get_node_execution = d.get_node_execution
    get_node_executions = d.get_node_executions
    get_user_by_id = d.get_user_by_id
    get_user_integrations = d.get_user_integrations
    upsert_execution_input = d.upsert_execution_input
    upsert_execution_output = d.upsert_execution_output
@@ -257,12 +241,10 @@ class DatabaseManagerAsyncClient(AppServiceClient):
    get_user_notification_preference = d.get_user_notification_preference

    # Notifications
    clear_all_user_notification_batches = d.clear_all_user_notification_batches
    create_or_add_to_user_notification_batch = (
        d.create_or_add_to_user_notification_batch
    )
    empty_user_notification_batch = d.empty_user_notification_batch
    remove_notifications_from_batch = d.remove_notifications_from_batch
    get_all_batches_by_type = d.get_all_batches_by_type
    get_user_notification_batch = d.get_user_notification_batch
    get_user_notification_oldest_message_in_batch = (

@@ -7,10 +7,8 @@ import uuid
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast

import sentry_sdk
from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from prometheus_client import Gauge, start_http_server
@@ -86,11 +84,7 @@ from backend.util.file import clean_exec_files
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.metrics import DiscordChannel
from backend.util.process import AppProcess, set_service_name
from backend.util.retry import (
    continuous_retry,
    func_retry,
    send_rate_limited_discord_alert,
)
from backend.util.retry import continuous_retry, func_retry
from backend.util.settings import Settings

from .cluster_lock import ClusterLock
@@ -190,7 +184,6 @@ async def execute_node(
        _input_data.inputs = input_data
        if nodes_input_masks:
            _input_data.nodes_input_masks = nodes_input_masks
        _input_data.user_id = user_id
        input_data = _input_data.model_dump()
        data.inputs = input_data

@@ -225,37 +218,14 @@ async def execute_node(
        extra_exec_kwargs[field_name] = credentials

    output_size = 0

    # sentry tracking nonsense to get user counts for blocks because isolation scopes don't work :(
    scope = sentry_sdk.get_current_scope()

    # save the tags
    original_user = scope._user
    original_tags = dict(scope._tags) if scope._tags else {}
    # Set user ID for error tracking
    scope.set_user({"id": user_id})

    scope.set_tag("graph_id", graph_id)
    scope.set_tag("node_id", node_id)
    scope.set_tag("block_name", node_block.name)
    scope.set_tag("block_id", node_block.id)
    for k, v in (data.user_context or UserContext(timezone="UTC")).model_dump().items():
        scope.set_tag(f"user_context.{k}", v)

    try:
        async for output_name, output_data in node_block.execute(
            input_data, **extra_exec_kwargs
        ):
            output_data = json.to_dict(output_data)
            output_data = json.convert_pydantic_to_json(output_data)
            output_size += len(json.dumps(output_data))
            log_metadata.debug("Node produced output", **{output_name: output_data})
            yield output_name, output_data
    except Exception:
        # Capture exception WITH context still set before restoring scope
        sentry_sdk.capture_exception(scope=scope)
        sentry_sdk.flush()  # Ensure it's sent before we restore scope
        # Re-raise to maintain normal error flow
        raise
    finally:
        # Ensure credentials are released even if execution fails
        if creds_lock and (await creds_lock.locked()) and (await creds_lock.owned()):
@@ -270,10 +240,6 @@ async def execute_node(
        execution_stats.input_size = input_size
        execution_stats.output_size = output_size

        # Restore scope AFTER error has been captured
        scope._user = original_user
        scope._tags = original_tags


async def _enqueue_next_nodes(
    db_client: "DatabaseManagerAsyncClient",
@@ -598,6 +564,7 @@ class ExecutionProcessor:
                await persist_output(
                    "error", str(stats.error) or type(stats.error).__name__
                )

        return status

    @func_retry
@@ -1012,31 +979,16 @@ class ExecutionProcessor:
                if isinstance(e, Exception)
                else Exception(f"{e.__class__.__name__}: {e}")
            )
            if not execution_stats.error:
                execution_stats.error = str(error)

            known_errors = (InsufficientBalanceError, ModerationError)
            if isinstance(error, known_errors):
                execution_stats.error = str(error)
                return ExecutionStatus.FAILED

            execution_status = ExecutionStatus.FAILED
            log_metadata.exception(
                f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
            )

            # Send rate-limited Discord alert for unknown/unexpected errors
            send_rate_limited_discord_alert(
                "graph_execution",
                error,
                "unknown_error",
                f"🚨 **Unknown Graph Execution Error**\n"
                f"User: {graph_exec.user_id}\n"
                f"Graph ID: {graph_exec.graph_id}\n"
                f"Execution ID: {graph_exec.graph_exec_id}\n"
                f"Error Type: {type(error).__name__}\n"
                f"Error: {str(error)[:200]}{'...' if len(str(error)) > 200 else ''}\n",
            )

            raise

        finally:
@@ -1211,9 +1163,9 @@ class ExecutionProcessor:
                f"❌ **Insufficient Funds Alert**\n"
                f"User: {user_email or user_id}\n"
                f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n"
                f"Current balance: ${e.balance / 100:.2f}\n"
                f"Attempted cost: ${abs(e.amount) / 100:.2f}\n"
                f"Shortfall: ${abs(shortfall) / 100:.2f}\n"
                f"Current balance: ${e.balance/100:.2f}\n"
                f"Attempted cost: ${abs(e.amount)/100:.2f}\n"
                f"Shortfall: ${abs(shortfall)/100:.2f}\n"
                f"[View User Details]({base_url}/admin/spending?search={user_email})"
            )

@@ -1260,9 +1212,9 @@ class ExecutionProcessor:
            alert_message = (
                f"⚠️ **Low Balance Alert**\n"
                f"User: {user_email or user_id}\n"
                f"Balance dropped below ${LOW_BALANCE_THRESHOLD / 100:.2f}\n"
                f"Current balance: ${current_balance / 100:.2f}\n"
                f"Transaction cost: ${transaction_cost / 100:.2f}\n"
                f"Balance dropped below ${LOW_BALANCE_THRESHOLD/100:.2f}\n"
                f"Current balance: ${current_balance/100:.2f}\n"
                f"Transaction cost: ${transaction_cost/100:.2f}\n"
                f"[View User Details]({base_url}/admin/spending?search={user_email})"
            )
            get_notification_manager_client().discord_system_alert(
@@ -1493,39 +1445,10 @@ class ExecutionManager(AppProcess):
            return

        graph_exec_id = graph_exec_entry.graph_exec_id
        user_id = graph_exec_entry.user_id
        graph_id = graph_exec_entry.graph_id
        logger.info(
            f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}, user_id={user_id}"
            f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}"
        )

        # Check user rate limit before processing
        try:
            # Only check executions from the last 24 hours for performance
            current_running_count = get_db_client().get_graph_executions_count(
                user_id=user_id,
                graph_id=graph_id,
                statuses=[ExecutionStatus.RUNNING],
                created_time_gte=datetime.now(timezone.utc) - timedelta(hours=24),
            )

            if (
                current_running_count
                >= settings.config.max_concurrent_graph_executions_per_user
            ):
                logger.warning(
                    f"[{self.service_name}] Rate limit exceeded for user {user_id} on graph {graph_id}: "
                    f"{current_running_count}/{settings.config.max_concurrent_graph_executions_per_user} running executions"
                )
                _ack_message(reject=True, requeue=True)
                return

        except Exception as e:
            logger.error(
                f"[{self.service_name}] Failed to check rate limit for user {user_id}: {e}, proceeding with execution"
            )
            # If rate limit check fails, proceed to avoid blocking executions

        # Check for local duplicate execution first
        if graph_exec_id in self.active_graph_runs:
            logger.warning(
@@ -1548,12 +1471,11 @@ class ExecutionManager(AppProcess):
                logger.warning(
                    f"[{self.service_name}] Graph {graph_exec_id} already running on pod {current_owner}"
                )
                _ack_message(reject=True, requeue=False)
            else:
                logger.warning(
                    f"[{self.service_name}] Could not acquire lock for {graph_exec_id} - Redis unavailable"
                )
                _ack_message(reject=True, requeue=True)
            _ack_message(reject=True, requeue=True)
            return
        self._execution_locks[graph_exec_id] = cluster_lock

@@ -1714,8 +1636,6 @@ class ExecutionManager(AppProcess):

        logger.info(f"{prefix} ✅ Finished GraphExec cleanup")

        super().cleanup()


# ------- UTILITIES ------- #

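The rate-limit block removed in the hunk above counts a user's RUNNING executions and requeues the message when the per-user concurrency cap is reached. A minimal, self-contained sketch of that gate with an in-memory counter (the names and the cap value are illustrative stand-ins, not the project's actual API):

```python
# Stand-in for settings.config.max_concurrent_graph_executions_per_user
MAX_CONCURRENT_PER_USER = 3

running: dict[str, int] = {}  # user_id -> currently running executions


def should_requeue(user_id: str) -> bool:
    """True when the cap is reached and the message should be requeued."""
    return running.get(user_id, 0) >= MAX_CONCURRENT_PER_USER


def try_start(user_id: str) -> bool:
    if should_requeue(user_id):
        return False  # caller would reject + requeue the queue message
    running[user_id] = running.get(user_id, 0) + 1
    return True


def finish(user_id: str) -> None:
    running[user_id] = max(0, running.get(user_id, 0) - 1)
```

Requeueing (rather than dropping) keeps the execution pending so it runs once one of the user's in-flight executions completes.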
@@ -248,7 +248,7 @@ class Scheduler(AppService):
            raise UnhealthyServiceError("Scheduler is still initializing")

        # Check if we're in the middle of cleanup
        if self._shutting_down:
        if self.cleaned_up:
            return await super().health_check()

        # Normal operation - check if scheduler is running
@@ -375,6 +375,7 @@ class Scheduler(AppService):
        super().run_service()

    def cleanup(self):
        super().cleanup()
        if self.scheduler:
            logger.info("⏳ Shutting down scheduler...")
            self.scheduler.shutdown(wait=True)
@@ -389,7 +390,7 @@ class Scheduler(AppService):
            logger.info("⏳ Waiting for event loop thread to finish...")
            _event_loop_thread.join(timeout=SCHEDULER_OPERATION_TIMEOUT_SECONDS)

        super().cleanup()
        logger.info("Scheduler cleanup complete.")

    @expose
    def add_graph_execution_schedule(

@@ -34,7 +34,6 @@ from backend.data.graph import GraphModel, Node
from backend.data.model import CredentialsMetaInput
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.user import get_user_by_id
from backend.util.cache import cached
from backend.util.clients import (
    get_async_execution_event_bus,
    get_async_execution_queue,
@@ -42,12 +41,11 @@ from backend.util.clients import (
    get_integration_credentials_store,
)
from backend.util.exceptions import GraphValidationError, NotFoundError
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
from backend.util.logging import TruncatedLogger
from backend.util.settings import Config
from backend.util.type import convert


@cached(maxsize=1000, ttl_seconds=3600)
async def get_user_context(user_id: str) -> UserContext:
    """
    Get UserContext for a user, always returns a valid context with timezone.
@@ -55,11 +53,7 @@ async def get_user_context(user_id: str) -> UserContext:
    """
    user_context = UserContext(timezone="UTC")  # Default to UTC
    try:
        if prisma.is_connected():
            user = await get_user_by_id(user_id)
        else:
            user = await get_database_manager_async_client().get_user_by_id(user_id)

        user = await get_user_by_id(user_id)
        if user and user.timezone and user.timezone != "not-set":
            user_context.timezone = user.timezone
            logger.debug(f"Retrieved user context: timezone={user.timezone}")
@@ -99,11 +93,7 @@ class LogMetadata(TruncatedLogger):
            "node_id": node_id,
            "block_name": block_name,
        }
        prefix = (
            "[ExecutionManager]"
            if is_structured_logging_enabled()
            else f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|neid:{node_eid}|{block_name}]"  # noqa
        )
        prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|neid:{node_eid}|{block_name}]"
        super().__init__(
            logger,
            max_length=max_length,

@@ -8,7 +8,7 @@ if TYPE_CHECKING:


# --8<-- [start:load_webhook_managers]
@cached(ttl_seconds=3600)
@cached(ttl_seconds=3600)  # Cache webhook managers for 1 hour
def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]:
    webhook_managers = {}


@@ -25,11 +25,7 @@ from backend.data.notifications import (
    get_summary_params_type,
)
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.user import (
    disable_all_user_notifications,
    generate_unsubscribe_link,
    set_user_email_verification,
)
from backend.data.user import generate_unsubscribe_link
from backend.notifications.email import EmailSender
from backend.util.clients import get_database_manager_async_client
from backend.util.logging import TruncatedLogger
@@ -42,7 +38,7 @@ from backend.util.service import (
    endpoint_to_sync,
    expose,
)
from backend.util.settings import AppEnvironment, Settings
from backend.util.settings import Settings

logger = TruncatedLogger(logging.getLogger(__name__), "[NotificationManager]")
settings = Settings()
@@ -128,12 +124,6 @@ def get_routing_key(event_type: NotificationType) -> str:

def queue_notification(event: NotificationEventModel) -> NotificationResult:
    """Queue a notification - exposed method for other services to call"""
    # Disable in production
    if settings.config.app_env == AppEnvironment.PRODUCTION:
        return NotificationResult(
            success=True,
            message="Queueing notifications is disabled in production",
        )
    try:
        logger.debug(f"Received Request to queue {event=}")

@@ -161,12 +151,6 @@ def queue_notification(event: NotificationEventModel) -> NotificationResult:

async def queue_notification_async(event: NotificationEventModel) -> NotificationResult:
    """Queue a notification - exposed method for other services to call"""
    # Disable in production
    if settings.config.app_env == AppEnvironment.PRODUCTION:
        return NotificationResult(
            success=True,
            message="Queueing notifications is disabled in production",
        )
    try:
        logger.debug(f"Received Request to queue {event=}")

@@ -229,9 +213,6 @@ class NotificationManager(AppService):

    @expose
    async def queue_weekly_summary(self):
        # disable in prod
        if settings.config.app_env == AppEnvironment.PRODUCTION:
            return
        # Use the existing event loop instead of creating a new one with asyncio.run()
        asyncio.create_task(self._queue_weekly_summary())

@@ -245,9 +226,7 @@ class NotificationManager(AppService):
            logger.info(
                f"Querying for active users between {start_time} and {current_time}"
            )
            users = await get_database_manager_async_client(
                should_retry=False
            ).get_active_user_ids_in_timerange(
            users = await get_database_manager_async_client().get_active_user_ids_in_timerange(
                end_time=current_time.isoformat(),
                start_time=start_time.isoformat(),
            )
@@ -274,9 +253,6 @@ class NotificationManager(AppService):
    async def process_existing_batches(
        self, notification_types: list[NotificationType]
    ):
        # disable in prod
        if settings.config.app_env == AppEnvironment.PRODUCTION:
            return
        # Use the existing event loop instead of creating a new process
        asyncio.create_task(self._process_existing_batches(notification_types))

@@ -290,15 +266,15 @@ class NotificationManager(AppService):

        for notification_type in notification_types:
            # Get all batches for this notification type
            batches = await get_database_manager_async_client(
                should_retry=False
            ).get_all_batches_by_type(notification_type)
            batches = (
                await get_database_manager_async_client().get_all_batches_by_type(
                    notification_type
                )
            )

            for batch in batches:
                # Check if batch has aged out
                oldest_message = await get_database_manager_async_client(
                    should_retry=False
                ).get_user_notification_oldest_message_in_batch(
                oldest_message = await get_database_manager_async_client().get_user_notification_oldest_message_in_batch(
                    batch.user_id, notification_type
                )

@@ -313,9 +289,9 @@ class NotificationManager(AppService):

                # If batch has aged out, process it
                if oldest_message.created_at + max_delay < current_time:
                    recipient_email = await get_database_manager_async_client(
                        should_retry=False
                    ).get_user_email_by_id(batch.user_id)
                    recipient_email = await get_database_manager_async_client().get_user_email_by_id(
                        batch.user_id
                    )

                    if not recipient_email:
                        logger.error(
@@ -332,25 +308,21 @@ class NotificationManager(AppService):
                            f"User {batch.user_id} does not want to receive {notification_type} notifications"
                        )
                        # Clear the batch
                        await get_database_manager_async_client(
                            should_retry=False
                        ).empty_user_notification_batch(
                        await get_database_manager_async_client().empty_user_notification_batch(
                            batch.user_id, notification_type
                        )
                        continue

                    batch_data = await get_database_manager_async_client(
                        should_retry=False
                    ).get_user_notification_batch(batch.user_id, notification_type)
                    batch_data = await get_database_manager_async_client().get_user_notification_batch(
                        batch.user_id, notification_type
                    )

                    if not batch_data or not batch_data.notifications:
                        logger.error(
                            f"Batch data not found for user {batch.user_id}"
                        )
                        # Clear the batch
                        await get_database_manager_async_client(
                            should_retry=False
                        ).empty_user_notification_batch(
                        await get_database_manager_async_client().empty_user_notification_batch(
                            batch.user_id, notification_type
                        )
                        continue
@@ -386,9 +358,7 @@ class NotificationManager(AppService):
                    )

                    # Clear the batch
                    await get_database_manager_async_client(
                        should_retry=False
                    ).empty_user_notification_batch(
                    await get_database_manager_async_client().empty_user_notification_batch(
                        batch.user_id, notification_type
                    )

@@ -443,13 +413,15 @@ class NotificationManager(AppService):
        self, user_id: str, event_type: NotificationType
    ) -> bool:
        """Check if a user wants to receive a notification based on their preferences and email verification status"""
        validated_email = await get_database_manager_async_client(
            should_retry=False
        ).get_user_email_verification(user_id)
        validated_email = (
            await get_database_manager_async_client().get_user_email_verification(
                user_id
            )
        )
        preference = (
            await get_database_manager_async_client(
                should_retry=False
            ).get_user_notification_preference(user_id)
            await get_database_manager_async_client().get_user_notification_preference(
                user_id
            )
        ).preferences.get(event_type, True)
        # only if both are true, should we email this person
        return validated_email and preference
@@ -465,9 +437,7 @@ class NotificationManager(AppService):

        try:
            # Get summary data from the database
            summary_data = await get_database_manager_async_client(
                should_retry=False
            ).get_user_execution_summary_data(
            summary_data = await get_database_manager_async_client().get_user_execution_summary_data(
                user_id=user_id,
                start_time=params.start_date,
                end_time=params.end_date,
@@ -554,13 +524,13 @@ class NotificationManager(AppService):
        self, user_id: str, event_type: NotificationType, event: NotificationEventModel
    ) -> bool:

        await get_database_manager_async_client(
            should_retry=False
        ).create_or_add_to_user_notification_batch(user_id, event_type, event)
        await get_database_manager_async_client().create_or_add_to_user_notification_batch(
            user_id, event_type, event
        )

        oldest_message = await get_database_manager_async_client(
            should_retry=False
        ).get_user_notification_oldest_message_in_batch(user_id, event_type)
        oldest_message = await get_database_manager_async_client().get_user_notification_oldest_message_in_batch(
            user_id, event_type
        )
        if not oldest_message:
            logger.error(
                f"Batch for user {user_id} and type {event_type} has no oldest message which should never happen!!!!!!!!!!!!!!!!"
@@ -610,9 +580,11 @@ class NotificationManager(AppService):
            return False
        logger.debug(f"Processing immediate notification: {event}")

        recipient_email = await get_database_manager_async_client(
            should_retry=False
        ).get_user_email_by_id(event.user_id)
        recipient_email = (
            await get_database_manager_async_client().get_user_email_by_id(
                event.user_id
            )
        )
        if not recipient_email:
            logger.error(f"User email not found for user {event.user_id}")
            return False
@@ -647,9 +619,11 @@ class NotificationManager(AppService):
            return False
        logger.info(f"Processing batch notification: {event}")

        recipient_email = await get_database_manager_async_client(
            should_retry=False
        ).get_user_email_by_id(event.user_id)
        recipient_email = (
            await get_database_manager_async_client().get_user_email_by_id(
                event.user_id
            )
        )
        if not recipient_email:
            logger.error(f"User email not found for user {event.user_id}")
            return False
@@ -668,9 +642,11 @@ class NotificationManager(AppService):
        if not should_send:
            logger.info("Batch not old enough to send")
            return False
        batch = await get_database_manager_async_client(
            should_retry=False
        ).get_user_notification_batch(event.user_id, event.type)
        batch = (
            await get_database_manager_async_client().get_user_notification_batch(
                event.user_id, event.type
            )
        )
        if not batch or not batch.notifications:
            logger.error(f"Batch not found for user {event.user_id}")
            return False
@@ -681,7 +657,6 @@ class NotificationManager(AppService):
                    get_notif_data_type(db_event.type)
                ].model_validate(
                    {
                        "id": db_event.id,  # Include ID from database
                        "user_id": event.user_id,
                        "type": db_event.type,
                        "data": db_event.data,
@@ -704,9 +679,6 @@ class NotificationManager(AppService):
                chunk_sent = False
                for attempt_size in [chunk_size, 50, 25, 10, 5, 1]:
                    chunk = batch_messages[i : i + attempt_size]
                    chunk_ids = [
                        msg.id for msg in chunk if msg.id
                    ]  # Extract IDs for removal

                    try:
                        # Try to render the email to check its size
@@ -733,23 +705,6 @@ class NotificationManager(AppService):
                            user_unsub_link=unsub_link,
                        )

                        # Remove successfully sent notifications immediately
                        if chunk_ids:
                            try:
                                await get_database_manager_async_client(
                                    should_retry=False
                                ).remove_notifications_from_batch(
                                    event.user_id, event.type, chunk_ids
                                )
                                logger.info(
                                    f"Removed {len(chunk_ids)} sent notifications from batch"
                                )
                            except Exception as e:
                                logger.error(
                                    f"Failed to remove sent notifications: {e}"
                                )
                                # Continue anyway - better to risk duplicates than lose emails

                        # Track successful sends
                        successfully_sent_count += len(chunk)

@@ -767,137 +722,13 @@ class NotificationManager(AppService):
                        i += len(chunk)
                        chunk_sent = True
                        break
                    else:
                        # Message is too large even after size reduction
                        if attempt_size == 1:
                            logger.error(
                                f"Failed to send notification at index {i}: "
                                f"Single notification exceeds email size limit "
                                f"({len(test_message):,} chars > {MAX_EMAIL_SIZE:,} chars). "
                                f"Removing permanently from batch - will not retry."
                            )

                            # Remove the oversized notification permanently - it will NEVER fit
                            if chunk_ids:
                                try:
                                    await get_database_manager_async_client(
                                        should_retry=False
                                    ).remove_notifications_from_batch(
                                        event.user_id, event.type, chunk_ids
                                    )
                                    logger.info(
                                        f"Removed oversized notification {chunk_ids[0]} from batch permanently"
                                    )
                                except Exception as e:
                                    logger.error(
                                        f"Failed to remove oversized notification: {e}"
                                    )

                            failed_indices.append(i)
                            i += 1
                            chunk_sent = True
                            break
                        # Try smaller chunk size
                        continue
                    except Exception as e:
                        # Check if it's a Postmark API error
                        if attempt_size == 1:
                            # Single notification failed - determine the actual cause
                            error_message = str(e).lower()
                            error_type = type(e).__name__

                            # Check for HTTP 406 - Inactive recipient (common in Postmark errors)
                            if "406" in error_message or "inactive" in error_message:
                                logger.warning(
                                    f"Failed to send notification at index {i}: "
                                    f"Recipient marked as inactive by Postmark. "
                                    f"Error: {e}. Disabling ALL notifications for this user."
                                )

                                # 1. Mark email as unverified
                                try:
                                    await set_user_email_verification(
                                        event.user_id, False
                                    )
                                    logger.info(
                                        f"Set email verification to false for user {event.user_id}"
                                    )
                                except Exception as deactivation_error:
                                    logger.error(
                                        f"Failed to deactivate email for user {event.user_id}: "
                                        f"{deactivation_error}"
                                    )

                                # 2. Disable all notification preferences
                                try:
                                    await disable_all_user_notifications(event.user_id)
                                    logger.info(
                                        f"Disabled all notification preferences for user {event.user_id}"
                                    )
                                except Exception as disable_error:
                                    logger.error(
                                        f"Failed to disable notification preferences: {disable_error}"
                                    )

                                # 3. Clear ALL notification batches for this user
                                try:
                                    await get_database_manager_async_client(
                                        should_retry=False
                                    ).clear_all_user_notification_batches(event.user_id)
                                    logger.info(
                                        f"Cleared ALL notification batches for user {event.user_id}"
                                    )
                                except Exception as remove_error:
                                    logger.error(
                                        f"Failed to clear batches for inactive recipient: {remove_error}"
                                    )

                                # Stop processing - we've nuked everything for this user
                                return True
                            # Check for HTTP 422 - Malformed data
                            elif (
                                "422" in error_message
                                or "unprocessable" in error_message
                            ):
                                logger.error(
                                    f"Failed to send notification at index {i}: "
                                    f"Malformed notification data rejected by Postmark. "
                                    f"Error: {e}. Removing from batch permanently."
                                )

                                # Remove from batch - 422 means bad data that won't fix itself
                                if chunk_ids:
                                    try:
                                        await get_database_manager_async_client(
                                            should_retry=False
                                        ).remove_notifications_from_batch(
                                            event.user_id, event.type, chunk_ids
                                        )
                                        logger.info(
                                            "Removed malformed notification from batch permanently"
                                        )
                                    except Exception as remove_error:
                                        logger.error(
                                            f"Failed to remove malformed notification: {remove_error}"
                                        )
                            # Check if it's a ValueError for size limit
                            elif (
                                isinstance(e, ValueError)
                                and "too large" in error_message
                            ):
                                logger.error(
                                    f"Failed to send notification at index {i}: "
                                    f"Notification size exceeds email limit. "
                                    f"Error: {e}. Skipping this notification."
                                )
                            # Other API errors
                            else:
                                logger.error(
                                    f"Failed to send notification at index {i}: "
                                    f"Email API error ({error_type}): {e}. "
                                    f"Skipping this notification."
                                )

                            # Even single notification is too large
                            logger.error(
                                f"Single notification too large to send: {e}. "
                                f"Skipping notification at index {i}"
                            )
                            failed_indices.append(i)
                            i += 1
                            chunk_sent = True
@@ -911,20 +742,18 @@ class NotificationManager(AppService):
                        failed_indices.append(i)
                        i += 1

            # Check what remains in the batch (notifications are removed as sent)
            remaining_batch = await get_database_manager_async_client(
                should_retry=False
            ).get_user_notification_batch(event.user_id, event.type)

            if not remaining_batch or not remaining_batch.notifications:
            # Only empty the batch if ALL notifications were sent successfully
            if successfully_sent_count == len(batch_messages):
                logger.info(
                    f"All {successfully_sent_count} notifications sent and removed from batch"
                    f"Successfully sent all {successfully_sent_count} notifications, clearing batch"
                )
                await get_database_manager_async_client().empty_user_notification_batch(
                    event.user_id, event.type
                )
            else:
                remaining_count = len(remaining_batch.notifications)
                logger.warning(
                    f"Sent {successfully_sent_count} notifications. "
                    f"{remaining_count} remain in batch for retry due to errors."
                    f"Only sent {successfully_sent_count} of {len(batch_messages)} notifications. "
                    f"Failed indices: {failed_indices}. Batch will be retained for retry."
                )
            return True
        except Exception as e:
@@ -942,9 +771,11 @@ class NotificationManager(AppService):

        logger.info(f"Processing summary notification: {model}")

        recipient_email = await get_database_manager_async_client(
            should_retry=False
        ).get_user_email_by_id(event.user_id)
        recipient_email = (
            await get_database_manager_async_client().get_user_email_by_id(
                event.user_id
            )
        )
        if not recipient_email:
            logger.error(f"User email not found for user {event.user_id}")
            return False
@@ -1017,14 +848,10 @@ class NotificationManager(AppService):
            logger.exception(f"Fatal error in consumer for {queue_name}: {e}")
            raise

    def run_service(self):
        # Queue the main _run_service task
        asyncio.run_coroutine_threadsafe(self._run_service(), self.shared_event_loop)

        # Start the main event loop
        super().run_service()

    @continuous_retry()
    def run_service(self):
        self.run_and_wait(self._run_service())

    async def _run_service(self):
        logger.info(f"[{self.service_name}] ⏳ Configuring RabbitMQ...")
        self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.rabbitmq_config)
@@ -1090,10 +917,9 @@ class NotificationManager(AppService):
    def cleanup(self):
        """Cleanup service resources"""
        self.running = False
        logger.info("⏳ Disconnecting RabbitMQ...")
        self.run_and_wait(self.rabbitmq_service.disconnect())

        super().cleanup()
        logger.info(f"[{self.service_name}] ⏳ Disconnecting RabbitMQ...")
        self.run_and_wait(self.rabbitmq_service.disconnect())


class NotificationManagerClient(AppServiceClient):

@@ -1,598 +0,0 @@
"""Tests for notification error handling in NotificationManager."""

from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest
from prisma.enums import NotificationType

from backend.data.notifications import AgentRunData, NotificationEventModel
from backend.notifications.notifications import NotificationManager


class TestNotificationErrorHandling:
    """Test cases for notification error handling in NotificationManager."""

    @pytest.fixture
    def notification_manager(self):
        """Create a NotificationManager instance for testing."""
        with patch("backend.notifications.notifications.AppService.__init__"):
            manager = NotificationManager()
            manager.email_sender = MagicMock()
            # Mock the _get_template method used by _process_batch
            template_mock = Mock()
            template_mock.base_template = "base"
            template_mock.subject_template = "subject"
            template_mock.body_template = "body"
            manager.email_sender._get_template = Mock(return_value=template_mock)
            # Mock the formatter
            manager.email_sender.formatter = Mock()
            manager.email_sender.formatter.format_email = Mock(
                return_value=("subject", "body content")
            )
            manager.email_sender.formatter.env = Mock()
            manager.email_sender.formatter.env.globals = {
                "base_url": "http://example.com"
            }
            return manager

    @pytest.fixture
    def sample_batch_event(self):
        """Create a sample batch event for testing."""
        return NotificationEventModel(
            type=NotificationType.AGENT_RUN,
            user_id="user_1",
            created_at=datetime.now(timezone.utc),
            data=AgentRunData(
                agent_name="Test Agent",
                credits_used=10.0,
                execution_time=5.0,
                node_count=3,
                graph_id="graph_1",
                outputs=[],
            ),
        )

    @pytest.fixture
    def sample_batch_notifications(self):
        """Create sample batch notifications for testing."""
        notifications = []
        for i in range(3):
            notification = Mock()
            notification.type = NotificationType.AGENT_RUN
            notification.data = {
                "agent_name": f"Test Agent {i}",
                "credits_used": 10.0 * (i + 1),
                "execution_time": 5.0 * (i + 1),
                "node_count": 3 + i,
                "graph_id": f"graph_{i}",
                "outputs": [],
            }
            notification.created_at = datetime.now(timezone.utc)
            notifications.append(notification)
        return notifications

    @pytest.mark.asyncio
    async def test_406_stops_all_processing_for_user(
        self, notification_manager, sample_batch_event
    ):
        """Test that 406 inactive recipient error stops ALL processing for that user."""
        with patch("backend.notifications.notifications.logger"), patch(
            "backend.notifications.notifications.set_user_email_verification",
            new_callable=AsyncMock,
        ) as mock_set_verification, patch(
            "backend.notifications.notifications.disable_all_user_notifications",
            new_callable=AsyncMock,
        ) as mock_disable_all, patch(
            "backend.notifications.notifications.get_database_manager_async_client"
        ) as mock_db_client, patch(
            "backend.notifications.notifications.generate_unsubscribe_link"
        ) as mock_unsub_link:

            # Create batch of 5 notifications
            notifications = []
            for i in range(5):
                notification = Mock()
                notification.id = f"notif_{i}"
                notification.type = NotificationType.AGENT_RUN
                notification.data = {
                    "agent_name": f"Test Agent {i}",
                    "credits_used": 10.0 * (i + 1),
                    "execution_time": 5.0 * (i + 1),
                    "node_count": 3 + i,
                    "graph_id": f"graph_{i}",
                    "outputs": [],
                }
                notification.created_at = datetime.now(timezone.utc)
                notifications.append(notification)

            # Setup mocks
            mock_db = mock_db_client.return_value
            mock_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
            mock_db.get_user_notification_batch = AsyncMock(
                return_value=Mock(notifications=notifications)
            )
            mock_db.clear_all_user_notification_batches = AsyncMock()
            mock_db.remove_notifications_from_batch = AsyncMock()
            mock_unsub_link.return_value = "http://example.com/unsub"

            # Mock internal methods
            notification_manager._should_email_user_based_on_preference = AsyncMock(
                return_value=True
            )
            notification_manager._should_batch = AsyncMock(return_value=True)
            notification_manager._parse_message = Mock(return_value=sample_batch_event)

            # Track calls
            call_count = [0]

            def send_side_effect(*args, **kwargs):
                data = kwargs.get("data", [])
                if isinstance(data, list) and len(data) == 1:
                    current_call = call_count[0]
                    call_count[0] += 1

                    # First two succeed, third hits 406
                    if current_call < 2:
                        return None
                    else:
                        raise Exception("Recipient marked as inactive (406)")
                # Force single processing
                raise Exception("Force single processing")

            notification_manager.email_sender.send_templated.side_effect = (
                send_side_effect
            )

            # Act
            result = await notification_manager._process_batch(
                sample_batch_event.model_dump_json()
            )

            # Assert
            assert result is True

            # Only 3 calls should have been made (2 successful, 1 failed with 406)
            assert call_count[0] == 3

            # User should be deactivated
            mock_set_verification.assert_called_once_with("user_1", False)
            mock_disable_all.assert_called_once_with("user_1")
            mock_db.clear_all_user_notification_batches.assert_called_once_with(
                "user_1"
            )

            # No further processing should occur after 406

    @pytest.mark.asyncio
    async def test_422_permanently_removes_malformed_notification(
        self, notification_manager, sample_batch_event
    ):
        """Test that 422 error permanently removes the malformed notification from batch and continues with others."""
        with patch("backend.notifications.notifications.logger") as mock_logger, patch(
            "backend.notifications.notifications.get_database_manager_async_client"
        ) as mock_db_client, patch(
            "backend.notifications.notifications.generate_unsubscribe_link"
        ) as mock_unsub_link:

            # Create batch of 5 notifications
            notifications = []
            for i in range(5):
                notification = Mock()
                notification.id = f"notif_{i}"
                notification.type = NotificationType.AGENT_RUN
                notification.data = {
                    "agent_name": f"Test Agent {i}",
                    "credits_used": 10.0 * (i + 1),
                    "execution_time": 5.0 * (i + 1),
                    "node_count": 3 + i,
                    "graph_id": f"graph_{i}",
                    "outputs": [],
                }
                notification.created_at = datetime.now(timezone.utc)
                notifications.append(notification)

            # Setup mocks
            mock_db = mock_db_client.return_value
            mock_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
            mock_db.get_user_notification_batch = AsyncMock(
                side_effect=[
                    Mock(notifications=notifications),
                    Mock(notifications=[]),  # Empty after processing
                ]
            )
            mock_db.remove_notifications_from_batch = AsyncMock()
            mock_unsub_link.return_value = "http://example.com/unsub"

            # Mock internal methods
            notification_manager._should_email_user_based_on_preference = AsyncMock(
                return_value=True
            )
            notification_manager._should_batch = AsyncMock(return_value=True)
            notification_manager._parse_message = Mock(return_value=sample_batch_event)

            # Track calls
            call_count = [0]
            successful_indices = []
            removed_notification_ids = []

            # Capture what gets removed
            def remove_side_effect(user_id, notif_type, notif_ids):
                removed_notification_ids.extend(notif_ids)
                return None

            mock_db.remove_notifications_from_batch.side_effect = remove_side_effect

            def send_side_effect(*args, **kwargs):
                data = kwargs.get("data", [])
                if isinstance(data, list) and len(data) == 1:
                    current_call = call_count[0]
                    call_count[0] += 1

                    # Index 2 has malformed data (422)
                    if current_call == 2:
                        raise Exception(
                            "Unprocessable entity (422): Malformed email data"
                        )
                    else:
                        successful_indices.append(current_call)
                        return None
                # Force single processing
                raise Exception("Force single processing")

            notification_manager.email_sender.send_templated.side_effect = (
                send_side_effect
            )

            # Act
            result = await notification_manager._process_batch(
                sample_batch_event.model_dump_json()
            )

            # Assert
            assert result is True
            assert call_count[0] == 5  # All 5 attempted
            assert len(successful_indices) == 4  # 4 succeeded (all except index 2)
            assert 2 not in successful_indices  # Index 2 failed

            # Verify 422 error was logged
            error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
            assert any(
                "422" in call or "malformed" in call.lower() for call in error_calls
            )

            # Verify all notifications were removed (4 successful + 1 malformed)
            assert mock_db.remove_notifications_from_batch.call_count == 5
            assert (
                "notif_2" in removed_notification_ids
            )  # Malformed one was removed permanently

    @pytest.mark.asyncio
    async def test_oversized_notification_permanently_removed(
        self, notification_manager, sample_batch_event
    ):
        """Test that oversized notifications are permanently removed from batch but others continue."""
        with patch("backend.notifications.notifications.logger") as mock_logger, patch(
            "backend.notifications.notifications.get_database_manager_async_client"
        ) as mock_db_client, patch(
            "backend.notifications.notifications.generate_unsubscribe_link"
        ) as mock_unsub_link:

            # Create batch of 5 notifications
            notifications = []
            for i in range(5):
                notification = Mock()
                notification.id = f"notif_{i}"
                notification.type = NotificationType.AGENT_RUN
                notification.data = {
                    "agent_name": f"Test Agent {i}",
                    "credits_used": 10.0 * (i + 1),
                    "execution_time": 5.0 * (i + 1),
                    "node_count": 3 + i,
                    "graph_id": f"graph_{i}",
                    "outputs": [],
                }
                notification.created_at = datetime.now(timezone.utc)
                notifications.append(notification)

            # Setup mocks
            mock_db = mock_db_client.return_value
            mock_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
            mock_db.get_user_notification_batch = AsyncMock(
                side_effect=[
                    Mock(notifications=notifications),
                    Mock(notifications=[]),  # Empty after processing
                ]
            )
            mock_db.remove_notifications_from_batch = AsyncMock()
            mock_unsub_link.return_value = "http://example.com/unsub"

            # Mock internal methods
            notification_manager._should_email_user_based_on_preference = AsyncMock(
                return_value=True
            )
            notification_manager._should_batch = AsyncMock(return_value=True)
            notification_manager._parse_message = Mock(return_value=sample_batch_event)

            # Override formatter to simulate oversized on index 3
            # original_format = notification_manager.email_sender.formatter.format_email

            def format_side_effect(*args, **kwargs):
                # Check if we're formatting index 3
                data = kwargs.get("data", {}).get("notifications", [])
                if data and len(data) == 1:
                    # Check notification content to identify index 3
                    if any(
                        "Test Agent 3" in str(n.data)
                        for n in data
                        if hasattr(n, "data")
                    ):
                        # Return oversized message for index 3
                        return ("subject", "x" * 5_000_000)  # Over 4.5MB limit
                return ("subject", "normal sized content")

            notification_manager.email_sender.formatter.format_email = Mock(
                side_effect=format_side_effect
            )

            # Track calls
            successful_indices = []

            def send_side_effect(*args, **kwargs):
                data = kwargs.get("data", [])
                if isinstance(data, list) and len(data) == 1:
                    # Track which notification was sent based on content
                    for i, notif in enumerate(notifications):
                        if any(
                            f"Test Agent {i}" in str(n.data)
                            for n in data
                            if hasattr(n, "data")
                        ):
                            successful_indices.append(i)
                            return None
                    return None
                # Force single processing
                raise Exception("Force single processing")

            notification_manager.email_sender.send_templated.side_effect = (
                send_side_effect
            )

            # Act
            result = await notification_manager._process_batch(
                sample_batch_event.model_dump_json()
            )

            # Assert
            assert result is True
            assert (
                len(successful_indices) == 4
            )  # Only 4 sent (index 3 skipped due to size)
            assert 3 not in successful_indices  # Index 3 was not sent

            # Verify oversized error was logged
            error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
            assert any(
                "exceeds email size limit" in call or "oversized" in call.lower()
                for call in error_calls
            )

    @pytest.mark.asyncio
    async def test_generic_api_error_keeps_notification_for_retry(
        self, notification_manager, sample_batch_event
    ):
        """Test that generic API errors keep notifications in batch for retry while others continue."""
        with patch("backend.notifications.notifications.logger") as mock_logger, patch(
            "backend.notifications.notifications.get_database_manager_async_client"
        ) as mock_db_client, patch(
            "backend.notifications.notifications.generate_unsubscribe_link"
        ) as mock_unsub_link:

            # Create batch of 5 notifications
            notifications = []
            for i in range(5):
                notification = Mock()
                notification.id = f"notif_{i}"
                notification.type = NotificationType.AGENT_RUN
                notification.data = {
                    "agent_name": f"Test Agent {i}",
                    "credits_used": 10.0 * (i + 1),
                    "execution_time": 5.0 * (i + 1),
                    "node_count": 3 + i,
                    "graph_id": f"graph_{i}",
                    "outputs": [],
                }
                notification.created_at = datetime.now(timezone.utc)
                notifications.append(notification)

            # Notification that failed with generic error
            failed_notifications = [notifications[1]]  # Only index 1 remains for retry

            # Setup mocks
            mock_db = mock_db_client.return_value
            mock_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
            mock_db.get_user_notification_batch = AsyncMock(
                side_effect=[
                    Mock(notifications=notifications),
                    Mock(
                        notifications=failed_notifications
                    ),  # Failed ones remain for retry
                ]
            )
            mock_db.remove_notifications_from_batch = AsyncMock()
            mock_unsub_link.return_value = "http://example.com/unsub"

            # Mock internal methods
            notification_manager._should_email_user_based_on_preference = AsyncMock(
                return_value=True
            )
            notification_manager._should_batch = AsyncMock(return_value=True)
            notification_manager._parse_message = Mock(return_value=sample_batch_event)

            # Track calls
            successful_indices = []
            failed_indices = []
            removed_notification_ids = []

            # Capture what gets removed
            def remove_side_effect(user_id, notif_type, notif_ids):
                removed_notification_ids.extend(notif_ids)
                return None

            mock_db.remove_notifications_from_batch.side_effect = remove_side_effect

            def send_side_effect(*args, **kwargs):
                data = kwargs.get("data", [])
                if isinstance(data, list) and len(data) == 1:
                    # Track which notification based on content
                    for i, notif in enumerate(notifications):
                        if any(
                            f"Test Agent {i}" in str(n.data)
                            for n in data
                            if hasattr(n, "data")
                        ):
                            # Index 1 has generic API error
                            if i == 1:
                                failed_indices.append(i)
                                raise Exception("Network timeout - temporary failure")
                            else:
                                successful_indices.append(i)
                                return None
                    return None
                # Force single processing
                raise Exception("Force single processing")

            notification_manager.email_sender.send_templated.side_effect = (
                send_side_effect
            )

            # Act
            result = await notification_manager._process_batch(
                sample_batch_event.model_dump_json()
            )

            # Assert
assert result is True
|
||||
assert len(successful_indices) == 4 # 4 succeeded (0, 2, 3, 4)
|
||||
assert len(failed_indices) == 1 # 1 failed
|
||||
assert 1 in failed_indices # Index 1 failed
|
||||
|
||||
# Verify generic error was logged
|
||||
error_calls = [call[0][0] for call in mock_logger.error.call_args_list]
|
||||
assert any(
|
||||
"api error" in call.lower() or "skipping" in call.lower()
|
||||
for call in error_calls
|
||||
)
|
||||
|
||||
# Only successful ones should be removed from batch (failed one stays for retry)
|
||||
assert mock_db.remove_notifications_from_batch.call_count == 4
|
||||
assert (
|
||||
"notif_1" not in removed_notification_ids
|
||||
) # Failed one NOT removed (stays for retry)
|
||||
assert "notif_0" in removed_notification_ids # Successful one removed
|
||||
assert "notif_2" in removed_notification_ids # Successful one removed
|
||||
assert "notif_3" in removed_notification_ids # Successful one removed
|
||||
assert "notif_4" in removed_notification_ids # Successful one removed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_all_notifications_sent_successfully(
|
||||
self, notification_manager, sample_batch_event
|
||||
):
|
||||
"""Test successful batch processing where all notifications are sent without errors."""
|
||||
with patch("backend.notifications.notifications.logger") as mock_logger, patch(
|
||||
"backend.notifications.notifications.get_database_manager_async_client"
|
||||
) as mock_db_client, patch(
|
||||
"backend.notifications.notifications.generate_unsubscribe_link"
|
||||
) as mock_unsub_link:
|
||||
|
||||
# Create batch of 5 notifications
|
||||
notifications = []
|
||||
for i in range(5):
|
||||
notification = Mock()
|
||||
notification.id = f"notif_{i}"
|
||||
notification.type = NotificationType.AGENT_RUN
|
||||
notification.data = {
|
||||
"agent_name": f"Test Agent {i}",
|
||||
"credits_used": 10.0 * (i + 1),
|
||||
"execution_time": 5.0 * (i + 1),
|
||||
"node_count": 3 + i,
|
||||
"graph_id": f"graph_{i}",
|
||||
"outputs": [],
|
||||
}
|
||||
notification.created_at = datetime.now(timezone.utc)
|
||||
notifications.append(notification)
|
||||
|
||||
# Setup mocks
|
||||
mock_db = mock_db_client.return_value
|
||||
mock_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
|
||||
mock_db.get_user_notification_batch = AsyncMock(
|
||||
side_effect=[
|
||||
Mock(notifications=notifications),
|
||||
Mock(notifications=[]), # Empty after all sent successfully
|
||||
]
|
||||
)
|
||||
mock_db.remove_notifications_from_batch = AsyncMock()
|
||||
mock_unsub_link.return_value = "http://example.com/unsub"
|
||||
|
||||
# Mock internal methods
|
||||
notification_manager._should_email_user_based_on_preference = AsyncMock(
|
||||
return_value=True
|
||||
)
|
||||
notification_manager._should_batch = AsyncMock(return_value=True)
|
||||
notification_manager._parse_message = Mock(return_value=sample_batch_event)
|
||||
|
||||
# Track successful sends
|
||||
successful_indices = []
|
||||
removed_notification_ids = []
|
||||
|
||||
# Capture what gets removed
|
||||
def remove_side_effect(user_id, notif_type, notif_ids):
|
||||
removed_notification_ids.extend(notif_ids)
|
||||
return None
|
||||
|
||||
mock_db.remove_notifications_from_batch.side_effect = remove_side_effect
|
||||
|
||||
def send_side_effect(*args, **kwargs):
|
||||
data = kwargs.get("data", [])
|
||||
if isinstance(data, list) and len(data) == 1:
|
||||
# Track which notification was sent
|
||||
for i, notif in enumerate(notifications):
|
||||
if any(
|
||||
f"Test Agent {i}" in str(n.data)
|
||||
for n in data
|
||||
if hasattr(n, "data")
|
||||
):
|
||||
successful_indices.append(i)
|
||||
return None
|
||||
return None # Success
|
||||
# Force single processing
|
||||
raise Exception("Force single processing")
|
||||
|
||||
notification_manager.email_sender.send_templated.side_effect = (
|
||||
send_side_effect
|
||||
)
|
||||
|
||||
# Act
|
||||
result = await notification_manager._process_batch(
|
||||
sample_batch_event.model_dump_json()
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
# All 5 notifications should be sent successfully
|
||||
assert len(successful_indices) == 5
|
||||
assert successful_indices == [0, 1, 2, 3, 4]
|
||||
|
||||
# All notifications should be removed from batch
|
||||
assert mock_db.remove_notifications_from_batch.call_count == 5
|
||||
assert len(removed_notification_ids) == 5
|
||||
for i in range(5):
|
||||
assert f"notif_{i}" in removed_notification_ids
|
||||
|
||||
# No errors should be logged
|
||||
assert mock_logger.error.call_count == 0
|
||||
|
||||
# Info message about successful sends should be logged
|
||||
info_calls = [call[0][0] for call in mock_logger.info.call_args_list]
|
||||
assert any("sent and removed" in call.lower() for call in info_calls)
|
||||
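The send/remove/retry semantics these tests exercise can be sketched in isolation. This is an assumption-level summary, not the real `_process_batch`: each notification is sent individually, successes are removed from the batch, and failures stay behind so a later run can retry them.

```python
def process_batch(notifications, send, remove):
    """Send each notification; return the ids left in the batch for retry."""
    remaining = []
    for notif in notifications:
        try:
            send(notif)
        except Exception:
            remaining.append(notif["id"])  # failed: keep for retry
        else:
            remove(notif["id"])  # sent: drop from batch
    return remaining


sent, removed = [], []


def send(n):
    # Simulate a transient failure for one notification, like the test above
    if n["id"] == "notif_1":
        raise Exception("Network timeout - temporary failure")
    sent.append(n["id"])


batch = [{"id": f"notif_{i}"} for i in range(5)]
left = process_batch(batch, send, removed.append)
```

After the run, only the failed id remains in the batch, mirroring the assertions in `test_generic_api_error_keeps_notification_for_retry`.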
autogpt_platform/backend/backend/server/cache_config.py (new file, 86 lines)
@@ -0,0 +1,86 @@
"""
Shared cache configuration constants.

This module defines all page_size defaults used across the application.
By centralizing these values, we ensure that cache invalidation always
uses the same page_size as the routes that populate the cache.

CRITICAL: If you change any of these values, the tests in
test_cache_invalidation_consistency.py will fail to remind you to
update all dependent code.
"""

# V1 API (legacy) page sizes
V1_GRAPHS_PAGE_SIZE = 250
"""Default page size for listing user graphs in v1 API."""

V1_LIBRARY_AGENTS_PAGE_SIZE = 10
"""Default page size for library agents in v1 API."""

V1_GRAPH_EXECUTIONS_PAGE_SIZE = 25
"""Default page size for graph executions in v1 API."""

# V2 Store API page sizes
V2_STORE_AGENTS_PAGE_SIZE = 20
"""Default page size for store agents listing."""

V2_STORE_CREATORS_PAGE_SIZE = 20
"""Default page size for store creators listing."""

V2_STORE_SUBMISSIONS_PAGE_SIZE = 20
"""Default page size for user submissions listing."""

V2_MY_AGENTS_PAGE_SIZE = 20
"""Default page size for user's own agents listing."""

# V2 Library API page sizes
V2_LIBRARY_AGENTS_PAGE_SIZE = 10
"""Default page size for library agents listing in v2 API."""

V2_LIBRARY_PRESETS_PAGE_SIZE = 20
"""Default page size for library presets listing."""

# Alternative page sizes (for backward compatibility or special cases)
V2_LIBRARY_PRESETS_ALT_PAGE_SIZE = 10
"""
Alternative page size for library presets.
Some clients may use this smaller page size, so cache clearing must handle both.
"""

V2_GRAPH_EXECUTIONS_ALT_PAGE_SIZE = 10
"""
Alternative page size for graph executions.
Some clients may use this smaller page size, so cache clearing must handle both.
"""

# Cache clearing configuration
MAX_PAGES_TO_CLEAR = 20
"""
Maximum number of pages to clear when invalidating paginated caches.
This prevents infinite loops while ensuring we clear most cached pages.
For users with more than 20 pages, those pages will expire naturally via TTL.
"""


def get_page_sizes_for_clearing(
    primary_page_size: int, alt_page_size: int | None = None
) -> list[int]:
    """
    Get all page_size values that should be cleared for a given cache.

    Args:
        primary_page_size: The main page_size used by the route
        alt_page_size: Optional alternative page_size if multiple clients use different sizes

    Returns:
        List of page_size values to clear

    Example:
        >>> get_page_sizes_for_clearing(20)
        [20]
        >>> get_page_sizes_for_clearing(20, 10)
        [20, 10]
    """
    if alt_page_size is None:
        return [primary_page_size]
    return [primary_page_size, alt_page_size]
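A minimal sketch of how these constants could drive a paginated-cache invalidation sweep: enumerate every (page, page_size) key up to `MAX_PAGES_TO_CLEAR` and delete each one. The `keys_to_clear` helper and key shape are assumptions for illustration; the two small definitions are copied so the sketch runs standalone.

```python
MAX_PAGES_TO_CLEAR = 20  # copied from the module above


def get_page_sizes_for_clearing(primary_page_size, alt_page_size=None):
    # copied from the module above
    if alt_page_size is None:
        return [primary_page_size]
    return [primary_page_size, alt_page_size]


def keys_to_clear(user_id, primary, alt=None, max_pages=MAX_PAGES_TO_CLEAR):
    """Hypothetical helper: enumerate (user_id, page, page_size) cache keys
    that an invalidation sweep would delete after a mutation."""
    return [
        (user_id, page, size)
        for size in get_page_sizes_for_clearing(primary, alt)
        for page in range(1, max_pages + 1)
    ]


keys = keys_to_clear("user-1", 20, 10)
```

With both a primary and an alternative page size, the sweep covers 2 sizes x 20 pages = 40 keys; pages beyond the cap are left to expire via TTL, as the `MAX_PAGES_TO_CLEAR` docstring describes.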
@@ -14,49 +14,19 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
 @pytest.fixture
 def test_user_id() -> str:
     """Test user ID fixture."""
-    return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
+    return "test-user-id"


 @pytest.fixture
 def admin_user_id() -> str:
     """Admin user ID fixture."""
-    return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
+    return "admin-user-id"


 @pytest.fixture
 def target_user_id() -> str:
     """Target user ID fixture."""
-    return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
-
-
-@pytest.fixture
-async def setup_test_user(test_user_id):
-    """Create test user in database before tests."""
-    from backend.data.user import get_or_create_user
-
-    # Create the test user in the database using JWT token format
-    user_data = {
-        "sub": test_user_id,
-        "email": "test@example.com",
-        "user_metadata": {"name": "Test User"},
-    }
-    await get_or_create_user(user_data)
-    return test_user_id
-
-
-@pytest.fixture
-async def setup_admin_user(admin_user_id):
-    """Create admin user in database before tests."""
-    from backend.data.user import get_or_create_user
-
-    # Create the admin user in the database using JWT token format
-    user_data = {
-        "sub": admin_user_id,
-        "email": "test-admin@example.com",
-        "user_metadata": {"name": "Test Admin"},
-    }
-    await get_or_create_user(user_data)
-    return admin_user_id
+    return "target-user-id"


 @pytest.fixture
@@ -64,7 +64,7 @@ class LoginResponse(BaseModel):
     state_token: str


-@router.get("/{provider}/login", summary="Initiate OAuth flow")
+@router.get("/{provider}/login")
 async def login(
     provider: Annotated[
         ProviderName, Path(title="The provider to initiate an OAuth flow for")
@@ -102,7 +102,7 @@ class CredentialsMetaResponse(BaseModel):
     )


-@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
+@router.post("/{provider}/callback")
 async def callback(
     provider: Annotated[
         ProviderName, Path(title="The target provider for this OAuth exchange")
@@ -321,6 +321,10 @@ class AgentServer(backend.util.service.AppProcess):

         uvicorn.run(**uvicorn_config)

+    def cleanup(self):
+        super().cleanup()
+        logger.info(f"[{self.service_name}] ⏳ Shutting down Agent Server...")
+
     @staticmethod
     async def test_execute_graph(
         graph_id: str,
autogpt_platform/backend/backend/server/routers/cache.py (new file, 154 lines)
@@ -0,0 +1,154 @@
"""
Cache functions for main V1 API endpoints.

This module contains all caching decorators and helpers for the V1 API,
separated from the main routes for better organization and maintainability.
"""

from typing import Sequence

from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data import user as user_db
from backend.data.block import get_blocks
from backend.util.cache import cached

# ===== Block Caches =====


# Cache block definitions with costs - they rarely change
@cached(maxsize=1, ttl_seconds=3600, shared_cache=True)
def get_cached_blocks() -> Sequence[dict]:
    """
    Get cached blocks with thundering herd protection.

    Uses cached decorator to prevent multiple concurrent requests
    from all executing the expensive block loading operation.
    """
    from backend.data.credit import get_block_cost

    block_classes = get_blocks()
    result = []

    for block_class in block_classes.values():
        block_instance = block_class()
        if not block_instance.disabled:
            # Get costs for this specific block class without creating another instance
            costs = get_block_cost(block_instance)
            result.append({**block_instance.to_dict(), "costs": costs})

    return result


# ===== Graph Caches =====


# Cache user's graphs list for 15 minutes
@cached(maxsize=1000, ttl_seconds=900, shared_cache=True)
async def get_cached_graphs(
    user_id: str,
    page: int,
    page_size: int,
):
    """Cached helper to get user's graphs."""
    return await graph_db.list_graphs_paginated(
        user_id=user_id,
        page=page,
        page_size=page_size,
    )


# Cache individual graph details for 30 minutes
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
async def get_cached_graph(
    graph_id: str,
    version: int | None,
    user_id: str,
):
    """Cached helper to get graph details."""
    return await graph_db.get_graph(
        graph_id=graph_id,
        version=version,
        user_id=user_id,
        include_subgraphs=True,  # needed to construct full credentials input schema
    )


# Cache graph versions for 30 minutes
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
async def get_cached_graph_all_versions(
    graph_id: str,
    user_id: str,
) -> Sequence[graph_db.GraphModel]:
    """Cached helper to get all versions of a graph."""
    return await graph_db.get_graph_all_versions(
        graph_id=graph_id,
        user_id=user_id,
    )


# ===== Execution Caches =====


# Cache graph executions for 10 seconds.
@cached(maxsize=1000, ttl_seconds=10, shared_cache=True)
async def get_cached_graph_executions(
    graph_id: str,
    user_id: str,
    page: int,
    page_size: int,
):
    """Cached helper to get graph executions."""
    return await execution_db.get_graph_executions_paginated(
        graph_id=graph_id,
        user_id=user_id,
        page=page,
        page_size=page_size,
    )


# Cache all user executions for 10 seconds.
@cached(maxsize=500, ttl_seconds=10, shared_cache=True)
async def get_cached_graphs_executions(
    user_id: str,
    page: int,
    page_size: int,
):
    """Cached helper to get all user's graph executions."""
    return await execution_db.get_graph_executions_paginated(
        user_id=user_id,
        page=page,
        page_size=page_size,
    )


# Cache individual execution details for 10 seconds.
@cached(maxsize=1000, ttl_seconds=10, shared_cache=True)
async def get_cached_graph_execution(
    graph_exec_id: str,
    user_id: str,
):
    """Cached helper to get graph execution details."""
    return await execution_db.get_graph_execution(
        user_id=user_id,
        execution_id=graph_exec_id,
        include_node_executions=False,
    )


# ===== User Preference Caches =====


# Cache user timezone for 1 hour
@cached(maxsize=1000, ttl_seconds=3600, shared_cache=True)
async def get_cached_user_timezone(user_id: str):
    """Cached helper to get user timezone."""
    user = await user_db.get_user_by_id(user_id)
    return {"timezone": user.timezone if user else "UTC"}


# Cache user preferences for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_user_preferences(user_id: str):
    """Cached helper to get user notification preferences."""
    return await user_db.get_user_notification_preference(user_id)
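The `backend.util.cache.cached` decorator itself is not part of this diff, so as a hedged sketch, here is a minimal in-process TTL cache exposing the same three helpers the routes and tests rely on (`cache_delete`, `cache_clear`, `cache_info`). The real decorator also supports async functions and a shared (cross-process) backend; everything below is an assumption for illustration, not the actual implementation.

```python
import time
from functools import wraps


def cached(maxsize=128, ttl_seconds=60, shared_cache=False):
    """Sketch of a TTL-cache decorator (sync, in-process only)."""

    def decorator(fn):
        store = {}  # args tuple -> (timestamp, value)

        @wraps(fn)
        def wrapper(*args):
            now = time.monotonic()
            hit = store.get(args)
            if hit is not None and now - hit[0] < ttl_seconds:
                return hit[1]  # fresh cache hit
            value = fn(*args)
            if len(store) >= maxsize:
                store.pop(next(iter(store)))  # evict oldest-inserted entry
            store[args] = (now, value)
            return value

        wrapper.cache_delete = lambda *args: store.pop(args, None) is not None
        wrapper.cache_clear = store.clear
        wrapper.cache_info = lambda: {
            "size": len(store),
            "maxsize": maxsize,
            "ttl_seconds": ttl_seconds,
        }
        return wrapper

    return decorator


db_calls = []


@cached(maxsize=10, ttl_seconds=60)
def get_user_timezone(user_id):
    db_calls.append(user_id)  # stands in for the real DB query
    return {"timezone": "UTC"}


get_user_timezone("u1")
get_user_timezone("u1")  # second call served from cache, no DB hit
```

Note `cache_delete` returns whether an entry was actually removed, which is what lets the tests assert "Failed to delete ... cache entry" on a miss.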
autogpt_platform/backend/backend/server/routers/cache_test.py (new file, 376 lines)
@@ -0,0 +1,376 @@
|
||||
"""
|
||||
Tests for cache invalidation in V1 API routes.
|
||||
|
||||
This module tests that caches are properly invalidated when data is modified
|
||||
through POST, PUT, PATCH, and DELETE operations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import backend.server.routers.cache as cache
|
||||
from backend.data import graph as graph_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_id():
|
||||
"""Generate a mock user ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_id():
|
||||
"""Generate a mock graph ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class TestGraphCacheInvalidation:
|
||||
"""Test cache invalidation for graph operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_graph_clears_list_cache(self, mock_user_id):
|
||||
"""Test that creating a graph clears the graphs list cache."""
|
||||
# Setup
|
||||
cache.get_cached_graphs.cache_clear()
|
||||
|
||||
# Pre-populate cache
|
||||
with patch.object(
|
||||
graph_db, "list_graphs_paginated", new_callable=AsyncMock
|
||||
) as mock_list:
|
||||
# Use a simple dict instead of MagicMock to make it pickleable
|
||||
mock_list.return_value = {
|
||||
"graphs": [],
|
||||
"total_count": 0,
|
||||
"page": 1,
|
||||
"page_size": 250,
|
||||
}
|
||||
|
||||
# First call should hit the database
|
||||
await cache.get_cached_graphs(mock_user_id, 1, 250)
|
||||
assert mock_list.call_count == 1
|
||||
|
||||
# Second call should use cache
|
||||
await cache.get_cached_graphs(mock_user_id, 1, 250)
|
||||
assert mock_list.call_count == 1 # Still 1, used cache
|
||||
|
||||
# Simulate cache invalidation (what happens in create_new_graph)
|
||||
cache.get_cached_graphs.cache_delete(mock_user_id, 1, 250)
|
||||
|
||||
# Next call should hit database again
|
||||
await cache.get_cached_graphs(mock_user_id, 1, 250)
|
||||
assert mock_list.call_count == 2 # Incremented, cache was cleared
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_graph_clears_multiple_caches(
|
||||
self, mock_user_id, mock_graph_id
|
||||
):
|
||||
"""Test that deleting a graph clears all related caches."""
|
||||
# Clear all caches first
|
||||
cache.get_cached_graphs.cache_clear()
|
||||
cache.get_cached_graph.cache_clear()
|
||||
cache.get_cached_graph_all_versions.cache_clear()
|
||||
cache.get_cached_graph_executions.cache_clear()
|
||||
|
||||
# Setup mocks
|
||||
with (
|
||||
patch.object(
|
||||
graph_db, "list_graphs_paginated", new_callable=AsyncMock
|
||||
) as mock_list,
|
||||
patch.object(graph_db, "get_graph", new_callable=AsyncMock) as mock_get,
|
||||
patch.object(
|
||||
graph_db, "get_graph_all_versions", new_callable=AsyncMock
|
||||
) as mock_versions,
|
||||
):
|
||||
mock_list.return_value = {
|
||||
"graphs": [],
|
||||
"total_count": 0,
|
||||
"page": 1,
|
||||
"page_size": 250,
|
||||
}
|
||||
mock_get.return_value = {"id": mock_graph_id}
|
||||
mock_versions.return_value = []
|
||||
|
||||
# Pre-populate all caches (use consistent argument style)
|
||||
await cache.get_cached_graphs(mock_user_id, 1, 250)
|
||||
await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
|
||||
await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
|
||||
|
||||
initial_calls = {
|
||||
"list": mock_list.call_count,
|
||||
"get": mock_get.call_count,
|
||||
"versions": mock_versions.call_count,
|
||||
}
|
||||
|
||||
# Use cached values (no additional DB calls)
|
||||
await cache.get_cached_graphs(mock_user_id, 1, 250)
|
||||
await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
|
||||
await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
|
||||
|
||||
# Verify cache was used
|
||||
assert mock_list.call_count == initial_calls["list"]
|
||||
assert mock_get.call_count == initial_calls["get"]
|
||||
assert mock_versions.call_count == initial_calls["versions"]
|
||||
|
||||
# Simulate delete_graph cache invalidation
|
||||
# Use positional arguments for cache_delete to match how we called the functions
|
||||
result1 = cache.get_cached_graphs.cache_delete(mock_user_id, 1, 250)
|
||||
result2 = cache.get_cached_graph.cache_delete(
|
||||
mock_graph_id, None, mock_user_id
|
||||
)
|
||||
result3 = cache.get_cached_graph_all_versions.cache_delete(
|
||||
mock_graph_id, mock_user_id
|
||||
)
|
||||
|
||||
# Verify that the cache entries were actually deleted
|
||||
assert result1, "Failed to delete graphs cache entry"
|
||||
assert result2, "Failed to delete graph cache entry"
|
||||
assert result3, "Failed to delete graph versions cache entry"
|
||||
|
||||
# Next calls should hit database
|
||||
await cache.get_cached_graphs(mock_user_id, 1, 250)
|
||||
await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
|
||||
await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
|
||||
|
||||
# Verify database was called again
|
||||
assert mock_list.call_count == initial_calls["list"] + 1
|
||||
assert mock_get.call_count == initial_calls["get"] + 1
|
||||
assert mock_versions.call_count == initial_calls["versions"] + 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_graph_clears_caches(self, mock_user_id, mock_graph_id):
|
||||
"""Test that updating a graph clears the appropriate caches."""
|
||||
# Clear caches
|
||||
cache.get_cached_graph.cache_clear()
|
||||
cache.get_cached_graph_all_versions.cache_clear()
|
||||
cache.get_cached_graphs.cache_clear()
|
||||
|
||||
with (
|
||||
patch.object(graph_db, "get_graph", new_callable=AsyncMock) as mock_get,
|
||||
patch.object(
|
||||
graph_db, "get_graph_all_versions", new_callable=AsyncMock
|
||||
) as mock_versions,
|
||||
patch.object(
|
||||
graph_db, "list_graphs_paginated", new_callable=AsyncMock
|
||||
) as mock_list,
|
||||
):
|
||||
mock_get.return_value = {"id": mock_graph_id, "version": 1}
|
||||
mock_versions.return_value = [{"version": 1}]
|
||||
mock_list.return_value = {
|
||||
"graphs": [],
|
||||
"total_count": 0,
|
||||
"page": 1,
|
||||
"page_size": 250,
|
||||
}
|
||||
|
||||
# Populate caches
|
||||
await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
|
||||
await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
|
||||
await cache.get_cached_graphs(mock_user_id, 1, 250)
|
||||
|
||||
initial_calls = {
|
||||
"get": mock_get.call_count,
|
||||
"versions": mock_versions.call_count,
|
||||
"list": mock_list.call_count,
|
||||
}
|
||||
|
||||
# Verify cache is being used
|
||||
await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
|
||||
await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
|
||||
await cache.get_cached_graphs(mock_user_id, 1, 250)
|
||||
|
||||
assert mock_get.call_count == initial_calls["get"]
|
||||
assert mock_versions.call_count == initial_calls["versions"]
|
||||
assert mock_list.call_count == initial_calls["list"]
|
||||
|
||||
# Simulate update_graph cache invalidation
|
||||
cache.get_cached_graph.cache_delete(mock_graph_id, None, mock_user_id)
|
||||
cache.get_cached_graph_all_versions.cache_delete(
|
||||
mock_graph_id, mock_user_id
|
||||
)
|
||||
cache.get_cached_graphs.cache_delete(mock_user_id, 1, 250)
|
||||
|
||||
# Next calls should hit database
|
||||
await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
|
||||
await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
|
||||
await cache.get_cached_graphs(mock_user_id, 1, 250)
|
||||
|
||||
assert mock_get.call_count == initial_calls["get"] + 1
|
||||
assert mock_versions.call_count == initial_calls["versions"] + 1
|
||||
assert mock_list.call_count == initial_calls["list"] + 1
|
||||
|
||||
|
||||
class TestUserPreferencesCacheInvalidation:
|
||||
"""Test cache invalidation for user preferences operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_preferences_clears_cache(self, mock_user_id):
|
||||
"""Test that updating preferences clears the preferences cache."""
|
||||
# Clear cache
|
||||
cache.get_cached_user_preferences.cache_clear()
|
||||
|
||||
with patch.object(
|
||||
cache.user_db, "get_user_notification_preference", new_callable=AsyncMock
|
||||
) as mock_get_prefs:
|
||||
mock_prefs = {"email_notifications": True, "push_notifications": False}
|
||||
mock_get_prefs.return_value = mock_prefs
|
||||
|
||||
# First call hits database
|
||||
result1 = await cache.get_cached_user_preferences(mock_user_id)
|
||||
assert mock_get_prefs.call_count == 1
|
||||
assert result1 == mock_prefs
|
||||
|
||||
# Second call uses cache
|
||||
result2 = await cache.get_cached_user_preferences(mock_user_id)
|
||||
assert mock_get_prefs.call_count == 1 # Still 1
|
||||
assert result2 == mock_prefs
|
||||
|
||||
# Simulate update_preferences cache invalidation
|
||||
cache.get_cached_user_preferences.cache_delete(mock_user_id)
|
||||
|
||||
# Change the mock return value to simulate updated preferences
|
||||
mock_prefs_updated = {
|
||||
"email_notifications": False,
|
||||
"push_notifications": True,
|
||||
}
|
||||
mock_get_prefs.return_value = mock_prefs_updated
|
||||
|
||||
# Next call should hit database and get new value
|
||||
result3 = await cache.get_cached_user_preferences(mock_user_id)
|
||||
assert mock_get_prefs.call_count == 2
|
||||
assert result3 == mock_prefs_updated
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timezone_cache_operations(self, mock_user_id):
|
||||
"""Test timezone cache and its operations."""
|
||||
# Clear cache
|
||||
cache.get_cached_user_timezone.cache_clear()
|
||||
|
||||
with patch.object(
|
||||
cache.user_db, "get_user_by_id", new_callable=AsyncMock
|
||||
) as mock_get_user:
|
||||
# Use a simple object that supports attribute access
|
||||
class MockUser:
|
||||
def __init__(self, timezone):
|
||||
self.timezone = timezone
|
||||
|
||||
mock_user = MockUser("America/New_York")
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# First call hits database
|
||||
result1 = await cache.get_cached_user_timezone(mock_user_id)
|
||||
assert mock_get_user.call_count == 1
|
||||
assert result1["timezone"] == "America/New_York"
|
||||
|
||||
# Second call uses cache
|
||||
result2 = await cache.get_cached_user_timezone(mock_user_id)
|
||||
assert mock_get_user.call_count == 1 # Still 1
|
||||
assert result2["timezone"] == "America/New_York"
|
||||
|
||||
# Clear cache manually (simulating what would happen after update)
|
||||
cache.get_cached_user_timezone.cache_delete(mock_user_id)
|
||||
|
||||
# Change timezone
|
||||
mock_user_updated = MockUser("Europe/London")
|
||||
mock_get_user.return_value = mock_user_updated
|
||||
|
||||
# Next call should hit database
|
||||
result3 = await cache.get_cached_user_timezone(mock_user_id)
|
||||
assert mock_get_user.call_count == 2
|
||||
assert result3["timezone"] == "Europe/London"
|
||||
|
||||
|
||||
class TestExecutionCacheInvalidation:
|
||||
"""Test cache invalidation for execution operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execution_cache_cleared_on_graph_delete(
|
||||
self, mock_user_id, mock_graph_id
|
||||
):
|
||||
"""Test that execution caches are cleared when a graph is deleted."""
|
||||
# Clear cache
|
||||
cache.get_cached_graph_executions.cache_clear()
|
||||
|
||||
with patch.object(
|
||||
cache.execution_db, "get_graph_executions_paginated", new_callable=AsyncMock
|
||||
) as mock_exec:
|
||||
mock_exec.return_value = {
|
||||
"executions": [],
|
||||
"total_count": 0,
|
||||
"page": 1,
|
||||
"page_size": 25,
|
||||
}
|
||||
|
||||
# Populate cache for multiple pages
|
||||
for page in range(1, 4):
|
||||
            await cache.get_cached_graph_executions(
                mock_graph_id, mock_user_id, page, 25
            )

        initial_calls = mock_exec.call_count

        # Verify cache is used
        for page in range(1, 4):
            await cache.get_cached_graph_executions(
                mock_graph_id, mock_user_id, page, 25
            )

        assert mock_exec.call_count == initial_calls  # No new calls

        # Simulate graph deletion clearing execution caches
        for page in range(1, 10):  # Clear more pages as done in delete_graph
            cache.get_cached_graph_executions.cache_delete(
                mock_graph_id, mock_user_id, page, 25
            )

        # Next calls should hit database
        for page in range(1, 4):
            await cache.get_cached_graph_executions(
                mock_graph_id, mock_user_id, page, 25
            )

        assert mock_exec.call_count == initial_calls + 3  # 3 new calls


class TestCacheInfo:
    """Test cache information and metrics."""

    def test_cache_info_returns_correct_metrics(self):
        """Test that cache_info returns correct metrics."""
        # Clear all caches
        cache.get_cached_graphs.cache_clear()
        cache.get_cached_graph.cache_clear()

        # Get initial info
        info_graphs = cache.get_cached_graphs.cache_info()
        info_graph = cache.get_cached_graph.cache_info()

        assert info_graphs["size"] == 0
        assert info_graph["size"] == 0

        # Note: We can't directly test cache population without real async context,
        # but we can verify the cache_info structure
        assert "size" in info_graphs
        assert "maxsize" in info_graphs
        assert "ttl_seconds" in info_graphs

    def test_cache_clear_removes_all_entries(self):
        """Test that cache_clear removes all entries."""
        # This test verifies the cache_clear method exists and can be called
        cache.get_cached_graphs.cache_clear()
        cache.get_cached_graph.cache_clear()
        cache.get_cached_graph_all_versions.cache_clear()
        cache.get_cached_graph_executions.cache_clear()
        cache.get_cached_graphs_executions.cache_clear()
        cache.get_cached_user_preferences.cache_clear()
        cache.get_cached_user_timezone.cache_clear()

        # After clear, all caches should be empty
        assert cache.get_cached_graphs.cache_info()["size"] == 0
        assert cache.get_cached_graph.cache_info()["size"] == 0
        assert cache.get_cached_graph_all_versions.cache_info()["size"] == 0
        assert cache.get_cached_graph_executions.cache_info()["size"] == 0
        assert cache.get_cached_graphs_executions.cache_info()["size"] == 0
        assert cache.get_cached_user_preferences.cache_info()["size"] == 0
        assert cache.get_cached_user_timezone.cache_info()["size"] == 0
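The tests above exercise a `cache_info()` / `cache_delete()` / `cache_clear()` interface on decorated functions. As a rough illustration only (this is not the project's `backend.util.cache` implementation, which is async; the decorator name and metric keys are merely mirrored from the tests), a minimal synchronous TTL cache exposing that interface could look like:

```python
import time
from functools import wraps


def cached(maxsize=128, ttl_seconds=300):
    """Minimal TTL cache sketch exposing cache_info/cache_delete/cache_clear."""

    def decorator(func):
        store = {}  # key -> (expiry timestamp, value)

        def make_key(args, kwargs):
            return (args, tuple(sorted(kwargs.items())))

        @wraps(func)
        def wrapper(*args, **kwargs):
            key = make_key(args, kwargs)
            hit = store.get(key)
            if hit is not None and hit[0] > time.monotonic():
                return hit[1]  # fresh cache hit
            value = func(*args, **kwargs)
            if len(store) >= maxsize:
                store.pop(next(iter(store)))  # evict oldest insertion
            store[key] = (time.monotonic() + ttl_seconds, value)
            return value

        def cache_info():
            # Same metric keys the tests assert on: size, maxsize, ttl_seconds
            return {"size": len(store), "maxsize": maxsize, "ttl_seconds": ttl_seconds}

        def cache_delete(*args, **kwargs):
            # Remove one entry by the same key the wrapper would compute
            store.pop(make_key(args, kwargs), None)

        wrapper.cache_info = cache_info
        wrapper.cache_delete = cache_delete
        wrapper.cache_clear = store.clear
        return wrapper

    return decorator
```

With this sketch, `fetch.cache_delete(x)` removes exactly one key while `fetch.cache_clear()` empties the store, which is the distinction the `TestCacheInfo` tests rely on.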
@@ -28,8 +28,11 @@ from pydantic import BaseModel
 from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
 from typing_extensions import Optional, TypedDict

+import backend.server.cache_config as cache_config
 import backend.server.integrations.router
 import backend.server.routers.analytics
+import backend.server.routers.cache as cache
+import backend.server.v2.library.cache as library_cache
 import backend.server.v2.library.db as library_db
 from backend.data import api_key as api_key_db
 from backend.data import execution as execution_db
@@ -39,7 +42,6 @@ from backend.data.credit import (
     AutoTopUpConfig,
     RefundRequest,
     TransactionHistory,
-    UserCredit,
     get_auto_top_up,
     get_user_credit_model,
     set_auto_top_up,
@@ -57,7 +59,6 @@ from backend.data.onboarding import (
 from backend.data.user import (
     get_or_create_user,
     get_user_by_id,
-    get_user_notification_preference,
     update_user_email,
     update_user_notification_preference,
     update_user_timezone,
@@ -84,7 +85,6 @@ from backend.server.model import (
     UpdateTimezoneRequest,
     UploadFileResponse,
 )
-from backend.util.cache import cached
 from backend.util.clients import get_scheduler_client
 from backend.util.cloud_storage import get_cloud_storage_handler
 from backend.util.exceptions import GraphValidationError, NotFoundError
@@ -108,6 +108,9 @@ def _create_file_size_error(size_bytes: int, max_size_mb: int) -> HTTPException:
 settings = Settings()
 logger = logging.getLogger(__name__)

+
+_user_credit_model = get_user_credit_model()
+
 # Define the API routes
 v1_router = APIRouter()

@@ -166,7 +169,9 @@ async def get_user_timezone_route(
 ) -> TimezoneResponse:
     """Get user timezone setting."""
     user = await get_or_create_user(user_data)
-    return TimezoneResponse(timezone=user.timezone)
+    # Use cached timezone for subsequent calls
+    result = await cache.get_cached_user_timezone(user.id)
+    return TimezoneResponse(timezone=result["timezone"])


 @v1_router.post(
@@ -180,6 +185,7 @@ async def update_user_timezone_route(
 ) -> TimezoneResponse:
     """Update user timezone. The timezone should be a valid IANA timezone identifier."""
     user = await update_user_timezone(user_id, str(request.timezone))
+    cache.get_cached_user_timezone.cache_delete(user_id)
     return TimezoneResponse(timezone=user.timezone)


@@ -192,7 +198,7 @@ async def update_user_timezone_route(
 async def get_preferences(
     user_id: Annotated[str, Security(get_user_id)],
 ) -> NotificationPreference:
-    preferences = await get_user_notification_preference(user_id)
+    preferences = await cache.get_cached_user_preferences(user_id)
     return preferences


@@ -207,6 +213,10 @@ async def update_preferences(
     preferences: NotificationPreferenceDTO = Body(...),
 ) -> NotificationPreference:
     output = await update_user_notification_preference(user_id, preferences)
+
+    # Clear preferences cache after update
+    cache.get_cached_user_preferences.cache_delete(user_id)
+
     return output


@@ -476,8 +486,7 @@ async def upload_file(
 async def get_user_credits(
     user_id: Annotated[str, Security(get_user_id)],
 ) -> dict[str, int]:
-    user_credit_model = await get_user_credit_model(user_id)
-    return {"credits": await user_credit_model.get_credits(user_id)}
+    return {"credits": await _user_credit_model.get_credits(user_id)}


 @v1_router.post(
@@ -489,8 +498,9 @@ async def get_user_credits(
 async def request_top_up(
     request: RequestTopUp, user_id: Annotated[str, Security(get_user_id)]
 ):
-    user_credit_model = await get_user_credit_model(user_id)
-    checkout_url = await user_credit_model.top_up_intent(user_id, request.credit_amount)
+    checkout_url = await _user_credit_model.top_up_intent(
+        user_id, request.credit_amount
+    )
     return {"checkout_url": checkout_url}

@@ -531,23 +539,18 @@ async def configure_user_auto_top_up(
     request: AutoTopUpConfig, user_id: Annotated[str, Security(get_user_id)]
 ) -> str:
     if request.threshold < 0:
-        raise HTTPException(status_code=422, detail="Threshold must be greater than 0")
+        raise ValueError("Threshold must be greater than 0")
     if request.amount < 500 and request.amount != 0:
-        raise HTTPException(
-            status_code=422, detail="Amount must be greater than or equal to 500"
-        )
-    if request.amount != 0 and request.amount < request.threshold:
-        raise HTTPException(
-            status_code=422, detail="Amount must be greater than or equal to threshold"
-        )
+        raise ValueError("Amount must be greater than or equal to 500")
+    if request.amount < request.threshold:
+        raise ValueError("Amount must be greater than or equal to threshold")

-    user_credit_model = await get_user_credit_model(user_id)
-    current_balance = await user_credit_model.get_credits(user_id)
+    current_balance = await _user_credit_model.get_credits(user_id)

     if current_balance < request.threshold:
-        await user_credit_model.top_up_credits(user_id, request.amount)
+        await _user_credit_model.top_up_credits(user_id, request.amount)
     else:
-        await user_credit_model.top_up_credits(user_id, 0)
+        await _user_credit_model.top_up_credits(user_id, 0)

     await set_auto_top_up(
         user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
@@ -595,13 +598,15 @@ async def stripe_webhook(request: Request):
         event["type"] == "checkout.session.completed"
         or event["type"] == "checkout.session.async_payment_succeeded"
     ):
-        await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"])
+        await _user_credit_model.fulfill_checkout(
+            session_id=event["data"]["object"]["id"]
+        )

     if event["type"] == "charge.dispute.created":
-        await UserCredit().handle_dispute(event["data"]["object"])
+        await _user_credit_model.handle_dispute(event["data"]["object"])

     if event["type"] == "refund.created" or event["type"] == "charge.dispute.closed":
-        await UserCredit().deduct_credits(event["data"]["object"])
+        await _user_credit_model.deduct_credits(event["data"]["object"])

     return Response(status_code=200)

@@ -615,8 +620,7 @@ async def stripe_webhook(request: Request):
 async def manage_payment_method(
     user_id: Annotated[str, Security(get_user_id)],
 ) -> dict[str, str]:
-    user_credit_model = await get_user_credit_model(user_id)
-    return {"url": await user_credit_model.create_billing_portal_session(user_id)}
+    return {"url": await _user_credit_model.create_billing_portal_session(user_id)}


 @v1_router.get(
@@ -634,8 +638,7 @@ async def get_credit_history(
     if transaction_count_limit < 1 or transaction_count_limit > 1000:
         raise ValueError("Transaction count limit must be between 1 and 1000")

-    user_credit_model = await get_user_credit_model(user_id)
-    return await user_credit_model.get_transaction_history(
+    return await _user_credit_model.get_transaction_history(
         user_id=user_id,
         transaction_time_ceiling=transaction_time,
         transaction_count_limit=transaction_count_limit,
@@ -652,8 +655,7 @@ async def get_credit_history(
 async def get_refund_requests(
     user_id: Annotated[str, Security(get_user_id)],
 ) -> list[RefundRequest]:
-    user_credit_model = await get_user_credit_model(user_id)
-    return await user_credit_model.get_refund_requests(user_id)
+    return await _user_credit_model.get_refund_requests(user_id)


 ########################################################
@@ -674,11 +676,10 @@ class DeleteGraphResponse(TypedDict):
 async def list_graphs(
     user_id: Annotated[str, Security(get_user_id)],
 ) -> Sequence[graph_db.GraphMeta]:
-    paginated_result = await graph_db.list_graphs_paginated(
+    paginated_result = await cache.get_cached_graphs(
         user_id=user_id,
         page=1,
         page_size=250,
         filter_by="active",
     )
     return paginated_result.graphs

@@ -701,13 +702,26 @@ async def get_graph(
     version: int | None = None,
     for_export: bool = False,
 ) -> graph_db.GraphModel:
-    graph = await graph_db.get_graph(
-        graph_id,
-        version,
-        user_id=user_id,
-        for_export=for_export,
-        include_subgraphs=True,  # needed to construct full credentials input schema
-    )
+    # Use cache for non-export requests
+    if not for_export:
+        graph = await cache.get_cached_graph(
+            graph_id=graph_id,
+            version=version,
+            user_id=user_id,
+        )
+        # If graph not found, clear cache entry as permissions may have changed
+        if not graph:
+            cache.get_cached_graph.cache_delete(
+                graph_id=graph_id, version=version, user_id=user_id
+            )
+    else:
+        graph = await graph_db.get_graph(
+            graph_id,
+            version,
+            user_id=user_id,
+            for_export=for_export,
+            include_subgraphs=True,  # needed to construct full credentials input schema
+        )
     if not graph:
         raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
     return graph
@@ -722,7 +736,7 @@ async def get_graph(
 async def get_graph_all_versions(
     graph_id: str, user_id: Annotated[str, Security(get_user_id)]
 ) -> Sequence[graph_db.GraphModel]:
-    graphs = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
+    graphs = await cache.get_cached_graph_all_versions(graph_id, user_id=user_id)
     if not graphs:
         raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
     return graphs
@@ -746,6 +760,26 @@ async def create_new_graph(
     # as the graph already valid and no sub-graphs are returned back.
     await graph_db.create_graph(graph, user_id=user_id)
     await library_db.create_library_agent(graph, user_id=user_id)

+    # Clear graphs list cache after creating new graph
+    cache.get_cached_graphs.cache_delete(
+        user_id=user_id,
+        page=1,
+        page_size=cache_config.V1_GRAPHS_PAGE_SIZE,
+    )
+    for page in range(1, cache_config.MAX_PAGES_TO_CLEAR):
+        library_cache.get_cached_library_agents.cache_delete(
+            user_id=user_id,
+            page=page,
+            page_size=cache_config.V1_LIBRARY_AGENTS_PAGE_SIZE,
+        )
+
+    # Clear my agents cache so user sees new agent immediately
+    import backend.server.v2.store.cache
+
+    backend.server.v2.store.cache._clear_my_agents_cache(user_id)
+
     return await on_graph_activate(graph, user_id=user_id)

@@ -761,7 +795,32 @@ async def delete_graph(
     if active_version := await graph_db.get_graph(graph_id, user_id=user_id):
         await on_graph_deactivate(active_version, user_id=user_id)

-    return {"version_counts": await graph_db.delete_graph(graph_id, user_id=user_id)}
+    result = DeleteGraphResponse(
+        version_counts=await graph_db.delete_graph(graph_id, user_id=user_id)
+    )
+
+    # Clear caches after deleting graph
+    cache.get_cached_graphs.cache_delete(
+        user_id=user_id,
+        page=1,
+        page_size=cache_config.V1_GRAPHS_PAGE_SIZE,
+    )
+    cache.get_cached_graph.cache_delete(
+        graph_id=graph_id, version=None, user_id=user_id
+    )
+    cache.get_cached_graph_all_versions.cache_delete(graph_id, user_id=user_id)
+
+    # Clear my agents cache so user sees agent removed immediately
+    import backend.server.v2.store.cache
+
+    backend.server.v2.store.cache._clear_my_agents_cache(user_id)
+
+    # Clear library agent by graph_id cache
+    library_cache.get_cached_library_agent_by_graph_id.cache_delete(
+        graph_id=graph_id, user_id=user_id
+    )
+
+    return result


 @v1_router.put(
@@ -817,6 +876,18 @@ async def update_graph(
         include_subgraphs=True,
     )
     assert new_graph_version_with_subgraphs  # make type checker happy

+    # Clear caches after updating graph
+    cache.get_cached_graph.cache_delete(
+        graph_id=graph_id, version=None, user_id=user_id
+    )
+    cache.get_cached_graph_all_versions.cache_delete(graph_id, user_id=user_id)
+    cache.get_cached_graphs.cache_delete(
+        user_id=user_id,
+        page=1,
+        page_size=cache_config.V1_GRAPHS_PAGE_SIZE,
+    )
+
     return new_graph_version_with_subgraphs

@@ -875,14 +946,36 @@ async def execute_graph(
     graph_version: Optional[int] = None,
     preset_id: Optional[str] = None,
 ) -> execution_db.GraphExecutionMeta:
-    user_credit_model = await get_user_credit_model(user_id)
-    current_balance = await user_credit_model.get_credits(user_id)
+    current_balance = await _user_credit_model.get_credits(user_id)
     if current_balance <= 0:
         raise HTTPException(
             status_code=402,
             detail="Insufficient balance to execute the agent. Please top up your account.",
         )

+    # Invalidate caches before execution starts so frontend sees fresh data
+    cache.get_cached_graphs_executions.cache_delete(
+        user_id=user_id,
+        page=1,
+        page_size=cache_config.V1_GRAPHS_PAGE_SIZE,
+    )
+    for page in range(1, cache_config.MAX_PAGES_TO_CLEAR):
+        cache.get_cached_graph_execution.cache_delete(
+            graph_id=graph_id, user_id=user_id, version=graph_version
+        )
+
+        cache.get_cached_graph_executions.cache_delete(
+            graph_id=graph_id,
+            user_id=user_id,
+            page=page,
+            page_size=cache_config.V1_GRAPH_EXECUTIONS_PAGE_SIZE,
+        )
+        library_cache.get_cached_library_agents.cache_delete(
+            user_id=user_id,
+            page=page,
+            page_size=cache_config.V1_LIBRARY_AGENTS_PAGE_SIZE,
+        )
+
     try:
         result = await execution_utils.add_graph_execution(
             graph_id=graph_id,
@@ -895,6 +988,7 @@ async def execute_graph(
         # Record successful graph execution
         record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
         record_graph_operation(operation="execute", status="success")
+
         return result
     except GraphValidationError as e:
         # Record failed graph execution
@@ -970,7 +1064,7 @@ async def _stop_graph_run(
 async def list_graphs_executions(
     user_id: Annotated[str, Security(get_user_id)],
 ) -> list[execution_db.GraphExecutionMeta]:
-    paginated_result = await execution_db.get_graph_executions_paginated(
+    paginated_result = await cache.get_cached_graphs_executions(
         user_id=user_id,
         page=1,
         page_size=250,
@@ -992,7 +1086,7 @@ async def list_graph_executions(
         25, ge=1, le=100, description="Number of executions per page"
     ),
 ) -> execution_db.GraphExecutionsPaginated:
-    return await execution_db.get_graph_executions_paginated(
+    return await cache.get_cached_graph_executions(
         graph_id=graph_id,
         user_id=user_id,
         page=page,

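The `execute_graph` hunk above clears several pages of paginated list caches in a loop before kicking off the run. The pattern can be sketched in isolation (a hedged illustration only; `clear_execution_pages`, the constant names, and the values are assumptions mirroring `cache_config`, not the project's actual API):

```python
# Illustrative constants mirroring cache_config (names and values assumed)
MAX_PAGES_TO_CLEAR = 5
GRAPH_EXECUTIONS_PAGE_SIZE = 25


def clear_execution_pages(
    cache_delete,
    graph_id: str,
    user_id: str,
    max_pages: int = MAX_PAGES_TO_CLEAR,
    page_size: int = GRAPH_EXECUTIONS_PAGE_SIZE,
) -> None:
    """Delete every cached page of a paginated listing, not just page 1,
    since a new execution can shift rows on any page.

    Note: range(1, max_pages) clears pages 1..max_pages-1, matching the
    diff's `for page in range(1, cache_config.MAX_PAGES_TO_CLEAR)` loop.
    """
    for page in range(1, max_pages):
        cache_delete(
            graph_id=graph_id,
            user_id=user_id,
            page=page,
            page_size=page_size,
        )
```

Because `cache_delete` is keyed on the exact argument tuple, a delete with a different `page_size` than the read path would silently miss the entry, which is why the real code passes the shared `cache_config` constants.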
@@ -23,13 +23,10 @@ client = fastapi.testclient.TestClient(app)


 @pytest.fixture(autouse=True)
-def setup_app_auth(mock_jwt_user, setup_test_user):
+def setup_app_auth(mock_jwt_user):
     """Setup auth overrides for all tests in this module"""
     from autogpt_libs.auth.jwt_utils import get_jwt_payload

-    # setup_test_user fixture already executed and user is created in database
-    # It returns the user_id which we don't need to await
-
     app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
     yield
     app.dependency_overrides.clear()
@@ -105,13 +102,13 @@ def test_get_graph_blocks(
     mock_block.id = "test-block"
     mock_block.disabled = False

-    # Mock get_blocks
+    # Mock get_blocks where it's imported at the top of v1.py
     mocker.patch(
         "backend.server.routers.v1.get_blocks",
         return_value={"test-block": lambda: mock_block},
     )

-    # Mock block costs
+    # Mock block costs where it's imported inside the function
     mocker.patch(
         "backend.data.credit.get_block_cost",
         return_value=[{"cost": 10, "type": "credit"}],
@@ -197,12 +194,8 @@ def test_get_user_credits(
     snapshot: Snapshot,
 ) -> None:
     """Test get user credits endpoint"""
-    mock_credit_model = Mock()
+    mock_credit_model = mocker.patch("backend.server.routers.v1._user_credit_model")
     mock_credit_model.get_credits = AsyncMock(return_value=1000)
-    mocker.patch(
-        "backend.server.routers.v1.get_user_credit_model",
-        return_value=mock_credit_model,
-    )

     response = client.get("/credits")

@@ -222,14 +215,10 @@ def test_request_top_up(
     snapshot: Snapshot,
 ) -> None:
     """Test request top up endpoint"""
-    mock_credit_model = Mock()
+    mock_credit_model = mocker.patch("backend.server.routers.v1._user_credit_model")
     mock_credit_model.top_up_intent = AsyncMock(
         return_value="https://checkout.example.com/session123"
     )
-    mocker.patch(
-        "backend.server.routers.v1.get_user_credit_model",
-        return_value=mock_credit_model,
-    )

     request_data = {"credit_amount": 500}

@@ -272,74 +261,6 @@ def test_get_auto_top_up(
     )


-def test_configure_auto_top_up(
-    mocker: pytest_mock.MockFixture,
-    snapshot: Snapshot,
-) -> None:
-    """Test configure auto top-up endpoint - this test would have caught the enum casting bug"""
-    # Mock the set_auto_top_up function to avoid database operations
-    mocker.patch(
-        "backend.server.routers.v1.set_auto_top_up",
-        return_value=None,
-    )
-
-    # Mock credit model to avoid Stripe API calls
-    mock_credit_model = mocker.AsyncMock()
-    mock_credit_model.get_credits.return_value = 50  # Current balance below threshold
-    mock_credit_model.top_up_credits.return_value = None
-
-    mocker.patch(
-        "backend.server.routers.v1.get_user_credit_model",
-        return_value=mock_credit_model,
-    )
-
-    # Test data
-    request_data = {
-        "threshold": 100,
-        "amount": 500,
-    }
-
-    response = client.post("/credits/auto-top-up", json=request_data)
-
-    # This should succeed with our fix, but would have failed before with the enum casting error
-    assert response.status_code == 200
-    assert response.json() == "Auto top-up settings updated"
-
-
-def test_configure_auto_top_up_validation_errors(
-    mocker: pytest_mock.MockFixture,
-) -> None:
-    """Test configure auto top-up endpoint validation"""
-    # Mock set_auto_top_up to avoid database operations for successful case
-    mocker.patch("backend.server.routers.v1.set_auto_top_up")
-
-    # Mock credit model to avoid Stripe API calls for the successful case
-    mock_credit_model = mocker.AsyncMock()
-    mock_credit_model.get_credits.return_value = 50
-    mock_credit_model.top_up_credits.return_value = None
-
-    mocker.patch(
-        "backend.server.routers.v1.get_user_credit_model",
-        return_value=mock_credit_model,
-    )
-
-    # Test negative threshold
-    response = client.post(
-        "/credits/auto-top-up", json={"threshold": -1, "amount": 500}
-    )
-    assert response.status_code == 422  # Validation error
-
-    # Test amount too small (but not 0)
-    response = client.post(
-        "/credits/auto-top-up", json={"threshold": 100, "amount": 100}
-    )
-    assert response.status_code == 422  # Validation error
-
-    # Test amount = 0 (should be allowed)
-    response = client.post("/credits/auto-top-up", json={"threshold": 100, "amount": 0})
-    assert response.status_code == 200  # Should succeed
-
-
 # Graphs endpoints tests
 def test_get_graphs(
     mocker: pytest_mock.MockFixture,

299
autogpt_platform/backend/backend/server/test_cache_audit.py
Normal file
@@ -0,0 +1,299 @@
#!/usr/bin/env python3
"""
Complete audit of all @cached functions to verify proper cache invalidation.

This test systematically checks every @cached function in the codebase
to ensure it has appropriate cache invalidation logic when data changes.
"""

import pytest


class TestCacheInvalidationAudit:
    """Audit all @cached functions for proper invalidation."""

    def test_v1_router_caches(self):
        """
        V1 Router cached functions:
        - _get_cached_blocks(): ✓ NEVER CHANGES (blocks are static in code)
        """
        # No invalidation needed for static data
        pass

    def test_v1_cache_module_graph_caches(self):
        """
        V1 Cache module graph-related caches:
        - get_cached_graphs(user_id, page, page_size): ✓ HAS INVALIDATION
          Cleared in: v1.py create_graph(), delete_graph(), update_graph_metadata(), stop_graph_execution()

        - get_cached_graph(graph_id, version, user_id): ✓ HAS INVALIDATION
          Cleared in: v1.py delete_graph(), update_graph(), delete_graph_execution()

        - get_cached_graph_all_versions(graph_id, user_id): ✓ HAS INVALIDATION
          Cleared in: v1.py delete_graph(), update_graph(), delete_graph_execution()

        - get_cached_graph_executions(graph_id, user_id, page, page_size): ✓ HAS INVALIDATION
          Cleared in: v1.py stop_graph_execution()
          Also cleared in: v2/library/routes/presets.py

        - get_cached_graphs_executions(user_id, page, page_size): ✓ HAS INVALIDATION
          Cleared in: v1.py stop_graph_execution()

        - get_cached_graph_execution(graph_exec_id, user_id): ✓ HAS INVALIDATION
          Cleared in: v1.py stop_graph_execution()

        ISSUE: All use hardcoded page_size values instead of cache_config constants!
        """
        # Document that v1 routes should migrate to use cache_config
        pass

    def test_v1_cache_module_user_caches(self):
        """
        V1 Cache module user-related caches:
        - get_cached_user_timezone(user_id): ✓ HAS INVALIDATION
          Cleared in: v1.py update_user_profile()

        - get_cached_user_preferences(user_id): ✓ HAS INVALIDATION
          Cleared in: v1.py update_user_notification_preferences()
        """
        pass

    def test_v2_store_cache_functions(self):
        """
        V2 Store cached functions:
        - _get_cached_user_profile(user_id): ✓ HAS INVALIDATION
          Cleared in: v2/store/routes.py update_or_create_profile()

        - _get_cached_store_agents(...): ⚠️ PARTIAL INVALIDATION
          Cleared in: v2/admin/store_admin_routes.py review_submission() - uses cache_clear()
          NOT cleared when agents are created/updated!

        - _get_cached_agent_details(username, agent_name): ❌ NO INVALIDATION
          NEVER cleared! Relies only on TTL (15 min)

        - _get_cached_agent_graph(store_listing_version_id): ❌ NO INVALIDATION
          NEVER cleared! Relies only on TTL (1 hour)

        - _get_cached_store_agent_by_version(store_listing_version_id): ❌ NO INVALIDATION
          NEVER cleared! Relies only on TTL (1 hour)

        - _get_cached_store_creators(...): ❌ NO INVALIDATION
          NEVER cleared! Relies only on TTL (1 hour)

        - _get_cached_creator_details(username): ❌ NO INVALIDATION
          NEVER cleared! Relies only on TTL (1 hour)

        - _get_cached_my_agents(user_id, page, page_size): ❌ NO INVALIDATION
          NEVER cleared! Users won't see new agents for 5 minutes!
          CRITICAL BUG: Should be cleared when user creates/deletes agents

        - _get_cached_submissions(user_id, page, page_size): ✓ HAS INVALIDATION
          Cleared via: _clear_submissions_cache() helper
          Called in: create_submission(), edit_submission(), delete_submission()
          Called in: v2/admin/store_admin_routes.py review_submission()
        """
        # Document critical issues
        CRITICAL_MISSING_INVALIDATION = [
            "_get_cached_my_agents - users won't see new agents immediately",
        ]

        # Acceptable TTL-only caches (documented, not asserted):
        # - _get_cached_agent_details (public data, 15min TTL acceptable)
        # - _get_cached_agent_graph (immutable data, 1hr TTL acceptable)
        # - _get_cached_store_agent_by_version (immutable version, 1hr TTL acceptable)
        # - _get_cached_store_creators (public data, 1hr TTL acceptable)
        # - _get_cached_creator_details (public data, 1hr TTL acceptable)

        assert (
            len(CRITICAL_MISSING_INVALIDATION) == 1
        ), "These caches need invalidation logic:\n" + "\n".join(
            CRITICAL_MISSING_INVALIDATION
        )

    def test_v2_library_cache_functions(self):
        """
        V2 Library cached functions:
        - get_cached_library_agents(user_id, page, page_size, ...): ✓ HAS INVALIDATION
          Cleared in: v1.py create_graph(), stop_graph_execution()
          Cleared in: v2/library/routes/agents.py add_library_agent(), remove_library_agent()

        - get_cached_library_agent_favorites(user_id, page, page_size): ✓ HAS INVALIDATION
          Cleared in: v2/library/routes/agents.py favorite/unfavorite endpoints

        - get_cached_library_agent(library_agent_id, user_id): ✓ HAS INVALIDATION
          Cleared in: v2/library/routes/agents.py remove_library_agent()

        - get_cached_library_agent_by_graph_id(graph_id, user_id): ❌ NO INVALIDATION
          NEVER cleared! Relies only on TTL (30 min)
          Should be cleared when graph is deleted

        - get_cached_library_agent_by_store_version(store_listing_version_id, user_id): ❌ NO INVALIDATION
          NEVER cleared! Relies only on TTL (1 hour)
          Probably acceptable as store versions are immutable

        - get_cached_library_presets(user_id, page, page_size): ✓ HAS INVALIDATION
          Cleared via: _clear_presets_list_cache() helper
          Called in: v2/library/routes/presets.py preset mutations

        - get_cached_library_preset(preset_id, user_id): ✓ HAS INVALIDATION
          Cleared in: v2/library/routes/presets.py preset mutations

        ISSUE: Clearing uses hardcoded page_size values (10 and 20) instead of cache_config!
        """
        pass

    def test_immutable_singleton_caches(self):
        """
        Caches that never need invalidation (singleton or immutable):
        - get_webhook_block_ids(): ✓ STATIC (blocks in code)
        - get_io_block_ids(): ✓ STATIC (blocks in code)
        - get_supabase(): ✓ CLIENT INSTANCE (no invalidation needed)
        - get_async_supabase(): ✓ CLIENT INSTANCE (no invalidation needed)
        - _get_all_providers(): ✓ STATIC CONFIG (providers in code)
        - get_redis(): ✓ CLIENT INSTANCE (no invalidation needed)
        - load_webhook_managers(): ✓ STATIC (managers in code)
        - load_all_blocks(): ✓ STATIC (blocks in code)
        - get_cached_blocks(): ✓ STATIC (blocks in code)
        """
        pass

    def test_feature_flag_cache(self):
        """
        Feature flag cache:
        - _fetch_user_context_data(user_id): ⚠️ LONG TTL
          TTL: 24 hours
          NO INVALIDATION

        This is probably acceptable as user context changes infrequently.
        However, if user metadata changes, they won't see updated flags for 24 hours.
        """
        pass

    def test_onboarding_cache(self):
        """
        Onboarding cache:
        - onboarding_enabled(): ⚠️ NO INVALIDATION
          TTL: 5 minutes
          NO INVALIDATION

        Should probably be cleared when store agents are added/removed.
        But 5min TTL is acceptable for this use case.
        """
        pass


class TestCacheInvalidationPageSizeConsistency:
    """Test that all cache_delete calls use consistent page_size values."""

    def test_v1_routes_hardcoded_page_sizes(self):
        """
        V1 routes use hardcoded page_size values that should migrate to cache_config:

        ❌ page_size=250 for graphs:
        - v1.py line 765: cache.get_cached_graphs.cache_delete(user_id, page=1, page_size=250)
        - v1.py line 791: cache.get_cached_graphs.cache_delete(user_id, page=1, page_size=250)
        - v1.py line 859: cache.get_cached_graphs.cache_delete(user_id, page=1, page_size=250)
        - v1.py line 929: cache.get_cached_graphs_executions.cache_delete(user_id, page=1, page_size=250)

        ❌ page_size=10 for library agents:
        - v1.py line 768: library_cache.get_cached_library_agents.cache_delete(..., page_size=10)
        - v1.py line 940: library_cache.get_cached_library_agents.cache_delete(..., page_size=10)

        ❌ page_size=25 for graph executions:
        - v1.py line 937: cache.get_cached_graph_executions.cache_delete(..., page_size=25)

        RECOMMENDATION: Create constants in cache_config and migrate v1 routes to use them.
        """
        from backend.server import cache_config

        # These constants exist but aren't used in v1 routes yet
        assert cache_config.V1_GRAPHS_PAGE_SIZE == 250
        assert cache_config.V1_LIBRARY_AGENTS_PAGE_SIZE == 10
        assert cache_config.V1_GRAPH_EXECUTIONS_PAGE_SIZE == 25

    def test_v2_library_routes_hardcoded_page_sizes(self):
        """
        V2 library routes use hardcoded page_size values:

        ❌ v2/library/routes/agents.py:
        - line 233: cache_delete(..., page_size=10)

        ❌ v2/library/routes/presets.py _clear_presets_list_cache():
        - Clears BOTH page_size=10 AND page_size=20
        - This suggests different consumers use different page sizes

        ❌ v2/library/routes/presets.py:
        - line 449: cache_delete(..., page_size=10)
        - line 452: cache_delete(..., page_size=25)

        RECOMMENDATION: Migrate to use cache_config constants.
        """
        from backend.server import cache_config

        # Constants exist for library
        assert cache_config.V2_LIBRARY_AGENTS_PAGE_SIZE == 10
        assert cache_config.V2_LIBRARY_PRESETS_PAGE_SIZE == 20
        assert cache_config.V2_LIBRARY_PRESETS_ALT_PAGE_SIZE == 10

    def test_only_page_1_cleared_risk(self):
        """
        Document cache_delete calls that only clear page=1.

        RISKY PATTERN: Many cache_delete calls only clear page=1:
        - v1.py create_graph(): Only clears page=1 of graphs
        - v1.py delete_graph(): Only clears page=1 of graphs
        - v1.py update_graph_metadata(): Only clears page=1 of graphs
        - v1.py stop_graph_execution(): Only clears page=1 of executions

        PROBLEM: If user has > 1 page, subsequent pages show stale data until TTL expires.

        SOLUTIONS:
        1. Use cache_clear() to clear all pages (nuclear option)
        2. Loop through multiple pages like _clear_submissions_cache does
        3. Accept TTL-based expiry for pages 2+ (current approach)

        Current approach is probably acceptable given TTL values are reasonable.
        """
        pass


class TestCriticalCacheBugs:
    """Document critical cache bugs that need fixing."""

    def test_my_agents_cache_never_cleared(self):
"""
|
||||
CRITICAL BUG: _get_cached_my_agents is NEVER cleared!
|
||||
|
||||
Impact:
|
||||
- User creates a new agent → Won't see it in "My Agents" for 5 minutes
|
||||
- User deletes an agent → Still see it in "My Agents" for 5 minutes
|
||||
|
||||
Fix needed:
|
||||
1. Create _clear_my_agents_cache() helper (like _clear_submissions_cache)
|
||||
2. Call it from v1.py create_graph() and delete_graph()
|
||||
3. Use cache_config.V2_MY_AGENTS_PAGE_SIZE constant
|
||||
|
||||
Location: v2/store/cache.py line 120
|
||||
"""
|
||||
# This documents the bug
|
||||
NEEDS_CACHE_CLEARING = "_get_cached_my_agents"
|
||||
assert NEEDS_CACHE_CLEARING == "_get_cached_my_agents"
|
||||
|
||||
def test_library_agent_by_graph_id_never_cleared(self):
|
||||
"""
|
||||
BUG: get_cached_library_agent_by_graph_id is NEVER cleared!
|
||||
|
||||
Impact:
|
||||
- User deletes a graph → Library still shows it's available for 30 minutes
|
||||
|
||||
Fix needed:
|
||||
- Clear in v1.py delete_graph()
|
||||
- Clear in v2/library/routes/agents.py remove_library_agent()
|
||||
|
||||
Location: v2/library/cache.py line 59
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
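The "fix needed" in `test_my_agents_cache_never_cleared` above can be sketched with a minimal, self-contained stand-in. This is illustrative only: the in-memory dict, the constants, and the names `get_my_agents`/`clear_my_agents_cache` are hypothetical stand-ins for the real `@cached` helper and the missing `_clear_my_agents_cache()`, which would live in `v2/store/cache.py`.

```python
# Hypothetical sketch of the missing clearing helper, mirroring the
# page-loop pattern that _clear_submissions_cache is described as using.
MY_AGENTS_PAGE_SIZE = 20   # stand-in for cache_config.V2_MY_AGENTS_PAGE_SIZE
MAX_PAGES_TO_CLEAR = 20    # stand-in for cache_config.MAX_PAGES_TO_CLEAR

_cache: dict[tuple, list] = {}


def get_my_agents(user_id: str, page: int = 1, page_size: int = MY_AGENTS_PAGE_SIZE) -> list:
    """Stand-in for the cached lookup: memoizes results per (user, page, size)."""
    key = (user_id, page, page_size)
    if key not in _cache:
        _cache[key] = [f"agent-{user_id}-{page}"]  # pretend this hits the DB
    return _cache[key]


def clear_my_agents_cache(user_id: str) -> int:
    """Evict every cached page for this user; returns how many entries were dropped."""
    evicted = 0
    for page in range(1, MAX_PAGES_TO_CLEAR + 1):
        if _cache.pop((user_id, page, MY_AGENTS_PAGE_SIZE), None) is not None:
            evicted += 1
    return evicted


get_my_agents("u1", page=1)
get_my_agents("u1", page=2)
print(clear_my_agents_cache("u1"))  # → 2
```

Calling a helper like this from `create_graph()` and `delete_graph()` is what the bug report above recommends; only pages fetched with the canonical page size are covered, which is exactly the coupling the page-size constants are meant to pin down.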
@@ -0,0 +1,95 @@
#!/usr/bin/env python3
"""
Test suite to verify cache_config constants are being used correctly.

This ensures that the centralized cache_config.py constants are actually
used throughout the codebase, not just defined.
"""

import pytest

from backend.server import cache_config


class TestCacheConfigConstants:
    """Verify cache_config constants have expected values."""

    def test_v2_store_page_sizes(self):
        """Test V2 Store API page size constants."""
        assert cache_config.V2_STORE_AGENTS_PAGE_SIZE == 20
        assert cache_config.V2_STORE_CREATORS_PAGE_SIZE == 20
        assert cache_config.V2_STORE_SUBMISSIONS_PAGE_SIZE == 20
        assert cache_config.V2_MY_AGENTS_PAGE_SIZE == 20

    def test_v2_library_page_sizes(self):
        """Test V2 Library API page size constants."""
        assert cache_config.V2_LIBRARY_AGENTS_PAGE_SIZE == 10
        assert cache_config.V2_LIBRARY_PRESETS_PAGE_SIZE == 20
        assert cache_config.V2_LIBRARY_PRESETS_ALT_PAGE_SIZE == 10

    def test_v1_page_sizes(self):
        """Test V1 API page size constants."""
        assert cache_config.V1_GRAPHS_PAGE_SIZE == 250
        assert cache_config.V1_LIBRARY_AGENTS_PAGE_SIZE == 10
        assert cache_config.V1_GRAPH_EXECUTIONS_PAGE_SIZE == 25

    def test_cache_clearing_config(self):
        """Test cache clearing configuration."""
        assert cache_config.MAX_PAGES_TO_CLEAR == 20

    def test_get_page_sizes_for_clearing_helper(self):
        """Test the helper function for getting page sizes to clear."""
        # Single page size
        result = cache_config.get_page_sizes_for_clearing(20)
        assert result == [20]

        # Multiple page sizes
        result = cache_config.get_page_sizes_for_clearing(20, 10)
        assert result == [20, 10]

        # With None alt_page_size
        result = cache_config.get_page_sizes_for_clearing(20, None)
        assert result == [20]


class TestCacheConfigUsage:
    """Test that cache_config constants are actually used in the code."""

    def test_store_routes_import_cache_config(self):
        """Verify store routes imports cache_config."""
        import backend.server.v2.store.routes as store_routes

        # Check that cache_config is imported
        assert hasattr(store_routes, "backend")
        assert hasattr(store_routes.backend.server, "cache_config")

    def test_store_cache_uses_constants(self):
        """Verify store cache module uses cache_config constants."""
        import backend.server.v2.store.cache as store_cache

        # Check the module imports cache_config
        assert hasattr(store_cache, "backend")
        assert hasattr(store_cache.backend.server, "cache_config")

        # The _clear_submissions_cache function should use the constant
        import inspect

        source = inspect.getsource(store_cache._clear_submissions_cache)
        assert (
            "cache_config.V2_STORE_SUBMISSIONS_PAGE_SIZE" in source
        ), "_clear_submissions_cache must use cache_config.V2_STORE_SUBMISSIONS_PAGE_SIZE"
        assert (
            "cache_config.MAX_PAGES_TO_CLEAR" in source
        ), "_clear_submissions_cache must use cache_config.MAX_PAGES_TO_CLEAR"

    def test_admin_routes_use_constants(self):
        """Verify admin routes use cache_config constants."""
        import backend.server.v2.admin.store_admin_routes as admin_routes

        # Check that cache_config is imported
        assert hasattr(admin_routes, "backend")
        assert hasattr(admin_routes.backend.server, "cache_config")


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
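The behavior asserted in `test_get_page_sizes_for_clearing_helper` above pins down a small contract. One possible implementation consistent with those assertions is sketched below; the real `cache_config.get_page_sizes_for_clearing` is not shown in this diff and may differ.

```python
from typing import Optional


def get_page_sizes_for_clearing(*page_sizes: Optional[int]) -> list[int]:
    """Return the distinct page sizes to loop over when clearing a paginated
    cache, preserving order and dropping None placeholders (e.g. an unused
    alt_page_size)."""
    result: list[int] = []
    for size in page_sizes:
        if size is not None and size not in result:
            result.append(size)
    return result


print(get_page_sizes_for_clearing(20))        # → [20]
print(get_page_sizes_for_clearing(20, 10))    # → [20, 10]
print(get_page_sizes_for_clearing(20, None))  # → [20]
```

Deduplicating here matters: a preset route that clears both page_size=20 and an alt size of 20 should not loop the same key range twice.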
@@ -0,0 +1,263 @@
#!/usr/bin/env python3
"""
Comprehensive test suite for cache invalidation consistency across the entire backend.

This test file identifies ALL locations where cache_delete is called with hardcoded
parameters (especially page_size) and ensures they match the corresponding route defaults.

CRITICAL: If any test in this file fails, it means cache invalidation will be broken
and users will see stale data after mutations.

Key problem areas identified:
1. v1.py routes: Uses page_size=250 for graphs, but cache clearing uses page_size=250 ✓
2. v1.py routes: Uses page_size=10 for library agents clearing
3. v2/library routes: Uses page_size=10 for library agents clearing
4. v2/store routes: Uses page_size=20 for submissions clearing (in _clear_submissions_cache)
5. v2/library presets: Uses page_size=10 AND page_size=20 for presets (dual clearing)
"""

import pytest


class TestCacheInvalidationConsistency:
    """Test that all cache_delete calls use correct parameters matching route defaults."""

    def test_v1_graphs_cache_page_size_consistency(self):
        """
        Test v1 graphs routes use consistent page_size.

        Locations that must match:
        - routes/v1.py line 682: default page_size=250
        - routes/v1.py line 765: cache_delete with page_size=250
        - routes/v1.py line 791: cache_delete with page_size=250
        - routes/v1.py line 859: cache_delete with page_size=250
        - routes/v1.py line 929: cache_delete with page_size=250
        - routes/v1.py line 1034: default page_size=250
        """
        V1_GRAPHS_DEFAULT_PAGE_SIZE = 250

        # This is the expected value - if this test fails, check all the above locations
        assert V1_GRAPHS_DEFAULT_PAGE_SIZE == 250, (
            "If you changed the default page_size for v1 graphs, you must update:\n"
            "1. routes/v1.py list_graphs() default parameter\n"
            "2. routes/v1.py create_graph() cache_delete call\n"
            "3. routes/v1.py delete_graph() cache_delete call\n"
            "4. routes/v1.py update_graph_metadata() cache_delete call\n"
            "5. routes/v1.py stop_graph_execution() cache_delete call\n"
            "6. routes/v1.py list_graph_run_events() default parameter"
        )

    def test_v1_library_agents_cache_page_size_consistency(self):
        """
        Test v1 library agents cache clearing uses consistent page_size.

        Locations that must match:
        - routes/v1.py line 768: cache_delete with page_size=10
        - routes/v1.py line 940: cache_delete with page_size=10
        - v2/library/routes/agents.py line 233: cache_delete with page_size=10

        WARNING: These hardcode page_size=10 but we need to verify this matches
        the actual page_size used when fetching library agents!
        """
        V1_LIBRARY_AGENTS_CLEARING_PAGE_SIZE = 10

        assert V1_LIBRARY_AGENTS_CLEARING_PAGE_SIZE == 10, (
            "If you changed the library agents clearing page_size, you must update:\n"
            "1. routes/v1.py create_graph() cache clearing loop\n"
            "2. routes/v1.py stop_graph_execution() cache clearing loop\n"
            "3. v2/library/routes/agents.py add_library_agent() cache clearing loop"
        )

        # TODO: This should be verified against the actual default used in library routes

    def test_v1_graph_executions_cache_page_size_consistency(self):
        """
        Test v1 graph executions cache clearing uses consistent page_size.

        Locations:
        - routes/v1.py line 937: cache_delete with page_size=25
        - v2/library/routes/presets.py line 449: cache_delete with page_size=10
        - v2/library/routes/presets.py line 452: cache_delete with page_size=25
        """
        V1_GRAPH_EXECUTIONS_CLEARING_PAGE_SIZE = 25

        # Note: presets.py clears BOTH page_size=10 AND page_size=25
        # This suggests there may be multiple consumers with different page sizes
        assert V1_GRAPH_EXECUTIONS_CLEARING_PAGE_SIZE == 25

    def test_v2_store_submissions_cache_page_size_consistency(self):
        """
        Test v2 store submissions use consistent page_size.

        Locations that must match:
        - v2/store/routes.py line 484: default page_size=20
        - v2/store/cache.py line 18: _clear_submissions_cache uses page_size=20

        This is already tested in test_cache_delete.py but documented here for completeness.
        """
        V2_STORE_SUBMISSIONS_DEFAULT_PAGE_SIZE = 20
        V2_STORE_SUBMISSIONS_CLEARING_PAGE_SIZE = 20

        assert (
            V2_STORE_SUBMISSIONS_DEFAULT_PAGE_SIZE
            == V2_STORE_SUBMISSIONS_CLEARING_PAGE_SIZE
        ), (
            "The default page_size for store submissions must match the hardcoded value in _clear_submissions_cache!\n"
            "Update both:\n"
            "1. v2/store/routes.py get_submissions() default parameter\n"
            "2. v2/store/cache.py _clear_submissions_cache() hardcoded page_size"
        )

    def test_v2_library_presets_cache_page_size_consistency(self):
        """
        Test v2 library presets cache clearing uses consistent page_size.

        Locations:
        - v2/library/routes/presets.py line 36: cache_delete with page_size=10
        - v2/library/routes/presets.py line 39: cache_delete with page_size=20

        This route clears BOTH page_size=10 and page_size=20, suggesting multiple consumers.
        """
        V2_LIBRARY_PRESETS_CLEARING_PAGE_SIZES = [10, 20]

        assert 10 in V2_LIBRARY_PRESETS_CLEARING_PAGE_SIZES
        assert 20 in V2_LIBRARY_PRESETS_CLEARING_PAGE_SIZES

        # TODO: Verify these match the actual page_size defaults used in preset routes

    def test_cache_clearing_helper_functions_documented(self):
        """
        Document all cache clearing helper functions and their hardcoded parameters.

        Helper functions that wrap cache_delete with hardcoded params:
        1. v2/store/cache.py::_clear_submissions_cache() - hardcodes page_size=20, num_pages=20
        2. v2/library/routes/presets.py::_clear_presets_list_cache() - hardcodes page_size=10 AND 20, num_pages=20

        These helpers are DANGEROUS because:
        - They hide the hardcoded parameters
        - They loop through multiple pages with hardcoded page_size
        - If the route default changes, these won't clear the right cache entries
        """
        HELPER_FUNCTIONS = {
            "_clear_submissions_cache": {
                "file": "v2/store/cache.py",
                "page_size": 20,
                "num_pages": 20,
                "risk": "HIGH - single page_size, could miss entries if default changes",
            },
            "_clear_presets_list_cache": {
                "file": "v2/library/routes/presets.py",
                "page_size": [10, 20],
                "num_pages": 20,
                "risk": "MEDIUM - clears multiple page_sizes, but could still miss new ones",
            },
        }

        assert (
            len(HELPER_FUNCTIONS) == 2
        ), "If you add new cache clearing helper functions, document them here!"

    def test_cache_delete_without_page_loops_are_risky(self):
        """
        Document cache_delete calls that clear only page=1 (risky if there are multiple pages).

        Single page cache_delete calls:
        - routes/v1.py line 765: Only clears page=1 with page_size=250
        - routes/v1.py line 791: Only clears page=1 with page_size=250
        - routes/v1.py line 859: Only clears page=1 with page_size=250

        These are RISKY because:
        - If a user has more than one page of graphs, pages 2+ won't be invalidated
        - User could see stale data on pagination

        RECOMMENDATION: Use cache_clear() or loop through multiple pages like
        _clear_submissions_cache does.
        """
        SINGLE_PAGE_CLEARS = [
            "routes/v1.py line 765: create_graph clears only page=1",
            "routes/v1.py line 791: delete_graph clears only page=1",
            "routes/v1.py line 859: update_graph_metadata clears only page=1",
        ]

        # This test documents the issue but doesn't fail
        # Consider this a TODO to fix these cache clearing strategies
        assert (
            len(SINGLE_PAGE_CLEARS) >= 3
        ), "These cache_delete calls should probably loop through multiple pages"

    def test_all_cached_functions_have_proper_invalidation(self):
        """
        Verify all @cached functions have corresponding cache_delete calls.

        Functions with proper invalidation:
        ✓ get_cached_user_profile - cleared on profile update
        ✓ get_cached_store_agents - cleared on admin review (cache_clear)
        ✓ get_cached_submissions - cleared via _clear_submissions_cache helper
        ✓ get_cached_graphs - cleared on graph mutations
        ✓ get_cached_library_agents - cleared on library changes

        Functions that might not have proper invalidation:
        ? get_cached_agent_details - not explicitly cleared
        ? get_cached_store_creators - not explicitly cleared
        ? get_cached_my_agents - not explicitly cleared (no helper function exists!)

        This is a documentation test - actual verification requires code analysis.
        """
        NEEDS_VERIFICATION = [
            "get_cached_agent_details",
            "get_cached_store_creators",
            "get_cached_my_agents",  # NO CLEARING FUNCTION EXISTS!
        ]

        assert "get_cached_my_agents" in NEEDS_VERIFICATION, (
            "get_cached_my_agents has no cache clearing logic - this is a BUG!\n"
            "When a user creates/deletes an agent, their 'my agents' list won't update."
        )


class TestCacheKeyParameterOrdering:
    """
    Test that cache_delete calls use the same parameter order as the @cached function.

    The @cached decorator uses function signature order to create cache keys.
    cache_delete must use the exact same order or it won't find the cached entry!
    """

    def test_cached_function_parameter_order_matters(self):
        """
        Document that parameter order in cache_delete must match @cached function signature.

        Example from v2/store/cache.py:

            @cached(...)
            async def _get_cached_submissions(user_id: str, page: int, page_size: int):
                ...

        CORRECT: _get_cached_submissions.cache_delete(user_id, page=1, page_size=20)
        WRONG:   _get_cached_submissions.cache_delete(page=1, user_id=user_id, page_size=20)

        The cached decorator generates keys based on the POSITIONAL order, so parameter
        order must match between the function definition and cache_delete call.
        """
        # This is a documentation test - no assertion needed
        # Real verification requires inspecting each cache_delete call
        pass

    def test_named_parameters_vs_positional_in_cache_delete(self):
        """
        Document best practice: use named parameters in cache_delete for safety.

        Good practice seen in codebase:
        - cache.get_cached_graphs.cache_delete(user_id=user_id, page=1, page_size=250)
        - library_cache.get_cached_library_agents.cache_delete(user_id=user_id, page=page, page_size=10)

        This is safer than positional arguments because:
        1. More readable
        2. Less likely to get order wrong
        3. Self-documenting what each parameter means
        """
        pass


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
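The key-ordering pitfall documented in `TestCacheKeyParameterOrdering` above can be demonstrated with a toy key builder. This is illustrative only: the real `@cached` decorator's key function is not shown in this diff, so the two strategies below are hypothetical stand-ins.

```python
import inspect


def make_key_naive(args, kwargs):
    """Key built straight from the call shape: a positional call and the
    equivalent keyword call produce DIFFERENT keys."""
    return (args, tuple(sorted(kwargs.items())))


def make_key_normalized(func, args, kwargs):
    """Bind against the signature so every call shape maps to one canonical key."""
    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()
    return tuple(bound.arguments.items())


def get_submissions(user_id: str, page: int = 1, page_size: int = 20):
    return f"{user_id}:{page}:{page_size}"


call_a = (("u1",), {"page": 1, "page_size": 20})  # mixed, as in cache_delete
call_b = (("u1", 1, 20), {})                      # fully positional, as when cached

print(make_key_naive(*call_a) == make_key_naive(*call_b))  # → False
print(make_key_normalized(get_submissions, *call_a)
      == make_key_normalized(get_submissions, *call_b))    # → True
```

If the decorator keys naively, `cache_delete` must replicate the exact call shape of the cached call; signature-binding normalization would remove that footgun entirely.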
@@ -457,7 +457,8 @@ async def test_api_key_with_unicode_characters_normalization_attack(mock_request
    """Test that Unicode normalization doesn't bypass validation."""
    # Create auth with composed Unicode character
    auth = APIKeyAuthenticator(
-        header_name="X-API-Key", expected_token="café"  # é is composed
+        header_name="X-API-Key",
+        expected_token="café",  # é is composed
    )

    # Try with decomposed version (c + a + f + e + ´)
@@ -522,8 +523,8 @@ async def test_api_keys_with_newline_variations(mock_request):
        "valid\r\ntoken",  # Windows newline
        "valid\rtoken",  # Mac newline
        "valid\x85token",  # NEL (Next Line)
-        "valid\x0Btoken",  # Vertical Tab
-        "valid\x0Ctoken",  # Form Feed
+        "valid\x0btoken",  # Vertical Tab
+        "valid\x0ctoken",  # Form Feed
    ]

    for api_key in newline_variations:
@@ -23,7 +23,6 @@ logger = logging.getLogger(__name__)


class AutoModManager:

    def __init__(self):
        self.config = self._load_config()
@@ -11,6 +11,8 @@ from backend.util.json import SafeJson

logger = logging.getLogger(__name__)

+_user_credit_model = get_user_credit_model()
+
router = APIRouter(
    prefix="/admin",
@@ -31,8 +33,7 @@ async def add_user_credits(
    logger.info(
        f"Admin user {admin_user_id} is adding {amount} credits to user {user_id}"
    )
-    user_credit_model = await get_user_credit_model(user_id)
-    new_balance, transaction_key = await user_credit_model._add_transaction(
+    new_balance, transaction_key = await _user_credit_model._add_transaction(
        user_id,
        amount,
        transaction_type=CreditTransactionType.GRANT,
@@ -1,5 +1,5 @@
import json
-from unittest.mock import AsyncMock, Mock
+from unittest.mock import AsyncMock

import fastapi
import fastapi.testclient
@@ -7,12 +7,12 @@ import prisma.enums
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
+from prisma import Json
from pytest_snapshot.plugin import Snapshot

import backend.server.v2.admin.credit_admin_routes as credit_admin_routes
import backend.server.v2.admin.model as admin_model
from backend.data.model import UserTransaction
-from backend.util.json import SafeJson
from backend.util.models import Pagination

app = fastapi.FastAPI()
@@ -37,14 +37,12 @@ def test_add_user_credits_success(
) -> None:
    """Test successful credit addition by admin"""
    # Mock the credit model
-    mock_credit_model = Mock()
+    mock_credit_model = mocker.patch(
+        "backend.server.v2.admin.credit_admin_routes._user_credit_model"
+    )
    mock_credit_model._add_transaction = AsyncMock(
        return_value=(1500, "transaction-123-uuid")
    )
-    mocker.patch(
-        "backend.server.v2.admin.credit_admin_routes.get_user_credit_model",
-        return_value=mock_credit_model,
-    )

    request_data = {
        "user_id": target_user_id,
@@ -64,17 +62,11 @@
    call_args = mock_credit_model._add_transaction.call_args
    assert call_args[0] == (target_user_id, 500)
    assert call_args[1]["transaction_type"] == prisma.enums.CreditTransactionType.GRANT
-    # Check that metadata is a SafeJson object with the expected content
-    assert isinstance(call_args[1]["metadata"], SafeJson)
-    actual_metadata = call_args[1]["metadata"]
-    expected_data = {
-        "admin_id": admin_user_id,
-        "reason": "Test credit grant for debugging",
-    }
-
-    # SafeJson inherits from Json which stores parsed data in the .data attribute
-    assert actual_metadata.data["admin_id"] == expected_data["admin_id"]
-    assert actual_metadata.data["reason"] == expected_data["reason"]
+    # Check that metadata is a Json object with the expected content
+    assert isinstance(call_args[1]["metadata"], Json)
+    assert call_args[1]["metadata"] == Json(
+        {"admin_id": admin_user_id, "reason": "Test credit grant for debugging"}
+    )

    # Snapshot test the response
    configured_snapshot.assert_match(
@@ -89,14 +81,12 @@ def test_add_user_credits_negative_amount(
) -> None:
    """Test credit deduction by admin (negative amount)"""
    # Mock the credit model
-    mock_credit_model = Mock()
+    mock_credit_model = mocker.patch(
+        "backend.server.v2.admin.credit_admin_routes._user_credit_model"
+    )
    mock_credit_model._add_transaction = AsyncMock(
        return_value=(200, "transaction-456-uuid")
    )
-    mocker.patch(
-        "backend.server.v2.admin.credit_admin_routes.get_user_credit_model",
-        return_value=mock_credit_model,
-    )

    request_data = {
        "user_id": "target-user-id",
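The mocking change in the hunks above swaps patching a factory's return value for patching a module-level singleton directly. The same pattern can be sketched without pytest-mock, using `unittest.mock` against a hypothetical `routes` namespace (all names below are stand-ins, not the real module):

```python
import asyncio
import types
from unittest.mock import AsyncMock, patch

# Hypothetical stand-in for a module holding a module-level credit model.
routes = types.SimpleNamespace(_user_credit_model=None)


async def add_credits(user_id: str, amount: int):
    """Uses the module-level singleton, like the refactored route above."""
    return await routes._user_credit_model._add_transaction(user_id, amount)


# Patch the singleton attribute itself; every caller sees the mock.
with patch.object(routes, "_user_credit_model") as mock_model:
    mock_model._add_transaction = AsyncMock(return_value=(1500, "tx-123"))
    balance, tx = asyncio.run(add_credits("u1", 500))

print(balance, tx)  # → 1500 tx-123
```

Patching the attribute (rather than the factory) is what makes the test work after the refactor: the module captures `get_user_credit_model()` at import time, so patching the factory afterwards would have no effect.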
@@ -7,7 +7,8 @@ import fastapi
import fastapi.responses
import prisma.enums

-import backend.server.v2.store.cache as store_cache
+import backend.server.cache_config
+import backend.server.v2.store.cache
import backend.server.v2.store.db
import backend.server.v2.store.model
import backend.util.json
@@ -30,7 +31,7 @@ async def get_admin_listings_with_versions(
    status: typing.Optional[prisma.enums.SubmissionStatus] = None,
    search: typing.Optional[str] = None,
    page: int = 1,
-    page_size: int = 20,
+    page_size: int = backend.server.cache_config.V2_STORE_SUBMISSIONS_PAGE_SIZE,
):
    """
    Get store listings with their version history for admins.
@@ -87,11 +88,6 @@ async def review_submission(
        StoreSubmission with updated review information
    """
    try:
-        already_approved = (
-            await backend.server.v2.store.db.check_submission_already_approved(
-                store_listing_version_id=store_listing_version_id,
-            )
-        )
        submission = await backend.server.v2.store.db.review_store_submission(
            store_listing_version_id=store_listing_version_id,
            is_approved=request.is_approved,
@@ -99,11 +95,8 @@
            internal_comments=request.internal_comments or "",
            reviewer_id=user_id,
        )

-        state_changed = already_approved != request.is_approved
        # Clear caches when the request is approved as it updates what is shown on the store
-        if state_changed:
-            store_cache.clear_all_caches()
+        backend.server.v2.store.cache._clear_submissions_cache(submission.user_id)
+        backend.server.v2.store.cache._get_cached_store_agents.cache_clear()
        return submission
    except Exception as e:
        logger.exception("Error reviewing submission: %s", e)
@@ -118,17 +118,6 @@ def get_blocks(
    )


-def get_block_by_id(block_id: str) -> BlockInfo | None:
-    """
-    Get a specific block by its ID.
-    """
-    for block_type in load_all_blocks().values():
-        block: Block[BlockSchema, BlockSchema] = block_type()
-        if block.id == block_id:
-            return block.get_info()
-    return None
-
-
def search_blocks(
    include_blocks: bool = True,
    include_integrations: bool = True,
@@ -53,6 +53,16 @@ class ProviderResponse(BaseModel):
    pagination: Pagination


+# Search
+class SearchRequest(BaseModel):
+    search_query: str | None = None
+    filter: list[FilterType] | None = None
+    by_creator: list[str] | None = None
+    search_id: str | None = None
+    page: int | None = None
+    page_size: int | None = None
+
+
class SearchBlocksResponse(BaseModel):
    blocks: BlockResponse
    total_block_count: int
@@ -110,25 +110,6 @@ async def get_blocks(
    )


-@router.get(
-    "/blocks/batch",
-    summary="Get specific blocks",
-    response_model=list[builder_model.BlockInfo],
-)
-async def get_specific_blocks(
-    block_ids: Annotated[list[str], fastapi.Query()],
-) -> list[builder_model.BlockInfo]:
-    """
-    Get specific blocks by their IDs.
-    """
-    blocks = []
-    for block_id in block_ids:
-        block = builder_db.get_block_by_id(block_id)
-        if block:
-            blocks.append(block)
-    return blocks
-
-
@router.get(
    "/providers",
    summary="Get Builder integration providers",
@@ -147,34 +128,30 @@ async def get_providers(
    )


-# Not using post method because on frontend, orval doesn't support Infinite Query with POST method.
-@router.get(
+@router.post(
    "/search",
    summary="Builder search",
    tags=["store", "private"],
    response_model=builder_model.SearchResponse,
)
async def search(
+    options: builder_model.SearchRequest,
    user_id: Annotated[str, fastapi.Security(get_user_id)],
-    search_query: Annotated[str | None, fastapi.Query()] = None,
-    filter: Annotated[list[str] | None, fastapi.Query()] = None,
-    search_id: Annotated[str | None, fastapi.Query()] = None,
-    by_creator: Annotated[list[str] | None, fastapi.Query()] = None,
-    page: Annotated[int, fastapi.Query()] = 1,
-    page_size: Annotated[int, fastapi.Query()] = 50,
) -> builder_model.SearchResponse:
    """
    Search for blocks (including integrations), marketplace agents, and user library agents.
    """
    # If no filters are provided, then we will return all types
-    if not filter:
-        filter = [
+    if not options.filter:
+        options.filter = [
            "blocks",
            "integrations",
            "marketplace_agents",
            "my_agents",
        ]
-    search_query = sanitize_query(search_query)
+    options.search_query = sanitize_query(options.search_query)
+    options.page = options.page or 1
+    options.page_size = options.page_size or 50

    # Blocks&Integrations
    blocks = builder_model.SearchBlocksResponse(
@@ -185,13 +162,13 @@
        total_block_count=0,
        total_integration_count=0,
    )
-    if "blocks" in filter or "integrations" in filter:
+    if "blocks" in options.filter or "integrations" in options.filter:
        blocks = builder_db.search_blocks(
-            include_blocks="blocks" in filter,
-            include_integrations="integrations" in filter,
-            query=search_query or "",
-            page=page,
-            page_size=page_size,
+            include_blocks="blocks" in options.filter,
+            include_integrations="integrations" in options.filter,
+            query=options.search_query or "",
+            page=options.page,
+            page_size=options.page_size,
        )

    # Library Agents
@@ -199,12 +176,12 @@
        agents=[],
        pagination=Pagination.empty(),
    )
-    if "my_agents" in filter:
+    if "my_agents" in options.filter:
        my_agents = await library_db.list_library_agents(
            user_id=user_id,
-            search_term=search_query,
-            page=page,
-            page_size=page_size,
+            search_term=options.search_query,
+            page=options.page,
+            page_size=options.page_size,
        )

    # Marketplace Agents
@@ -212,12 +189,12 @@
        agents=[],
        pagination=Pagination.empty(),
    )
-    if "marketplace_agents" in filter:
+    if "marketplace_agents" in options.filter:
        marketplace_agents = await store_db.get_store_agents(
-            creators=by_creator,
-            search_query=search_query,
-            page=page,
-            page_size=page_size,
+            creators=options.by_creator,
+            search_query=options.search_query,
+            page=options.page,
+            page_size=options.page_size,
        )

    more_pages = False
@@ -237,7 +214,7 @@
            "marketplace_agents": marketplace_agents.pagination.total_items,
            "my_agents": my_agents.pagination.total_items,
        },
-        page=page,
+        page=options.page,
        more_pages=more_pages,
    )
111 autogpt_platform/backend/backend/server/v2/library/cache.py Normal file
@@ -0,0 +1,111 @@
"""
Cache functions for Library API endpoints.

This module contains all caching decorators and helpers for the Library API,
separated from the main routes for better organization and maintainability.
"""

import backend.server.v2.library.db
from backend.util.cache import cached

# ===== Library Agent Caches =====


# Cache library agents list for 10 minutes
@cached(maxsize=1000, ttl_seconds=600, shared_cache=True)
async def get_cached_library_agents(
    user_id: str,
    page: int = 1,
    page_size: int = 20,
):
    """Cached helper to get library agents list."""
    return await backend.server.v2.library.db.list_library_agents(
        user_id=user_id,
        page=page,
        page_size=page_size,
    )


# Cache user's favorite agents for 5 minutes - favorites change more frequently
@cached(maxsize=500, ttl_seconds=300, shared_cache=True)
async def get_cached_library_agent_favorites(
    user_id: str,
    page: int = 1,
    page_size: int = 20,
):
    """Cached helper to get user's favorite library agents."""
    return await backend.server.v2.library.db.list_favorite_library_agents(
        user_id=user_id,
        page=page,
        page_size=page_size,
    )


# Cache individual library agent details for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_agent(
    library_agent_id: str,
    user_id: str,
):
    """Cached helper to get library agent details."""
    return await backend.server.v2.library.db.get_library_agent(
        id=library_agent_id,
        user_id=user_id,
    )


# Cache library agent by graph ID for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_agent_by_graph_id(
    graph_id: str,
    user_id: str,
):
    """Cached helper to get library agent by graph ID."""
    return await backend.server.v2.library.db.get_library_agent_by_graph_id(
        graph_id=graph_id,
        user_id=user_id,
    )


# Cache library agent by store version ID for 1 hour - marketplace agents are more stable
@cached(maxsize=500, ttl_seconds=3600, shared_cache=True)
async def get_cached_library_agent_by_store_version(
    store_listing_version_id: str,
    user_id: str,
):
    """Cached helper to get library agent by store version ID."""
    return await backend.server.v2.library.db.get_library_agent_by_store_version_id(
        store_listing_version_id=store_listing_version_id,
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
# ===== Library Preset Caches =====
|
||||
|
||||
|
||||
# Cache library presets list for 30 minutes
|
||||
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
|
||||
async def get_cached_library_presets(
|
||||
user_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
):
|
||||
"""Cached helper to get library presets list."""
|
||||
return await backend.server.v2.library.db.list_presets(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
# Cache individual preset details for 30 minutes
|
||||
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
|
||||
async def get_cached_library_preset(
|
||||
preset_id: str,
|
||||
user_id: str,
|
||||
):
|
||||
"""Cached helper to get library preset details."""
|
||||
return await backend.server.v2.library.db.get_preset(
|
||||
preset_id=preset_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
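The `@cached` decorator imported from `backend.util.cache` is not shown in this diff; only its surface is visible — the `maxsize`, `ttl_seconds`, and `shared_cache` parameters, plus the `cache_clear` / `cache_delete` / `cache_info` attributes used throughout the new code. A minimal in-process sketch of an async decorator with that surface, assuming the Redis-backed behavior behind `shared_cache=True` is out of scope:

```python
import asyncio
import time
from functools import wraps


def cached(maxsize: int = 128, ttl_seconds: float = 60.0, shared_cache: bool = False):
    """In-process stand-in for backend.util.cache.cached (sketch, not the real thing).

    shared_cache is accepted but ignored here; the real implementation
    presumably swaps the dict for a shared Redis store when it is True.
    """

    def decorator(fn):
        store: dict = {}  # key -> (expiry timestamp, value)

        def _key(args, kwargs):
            return (args, tuple(sorted(kwargs.items())))

        @wraps(fn)
        async def wrapper(*args, **kwargs):
            key = _key(args, kwargs)
            hit = store.get(key)
            if hit is not None and hit[0] > time.monotonic():
                return hit[1]  # fresh entry: skip the wrapped call
            value = await fn(*args, **kwargs)
            if len(store) >= maxsize:
                store.pop(next(iter(store)))  # evict the oldest insertion
            store[key] = (time.monotonic() + ttl_seconds, value)
            return value

        # The attribute surface the routes and tests rely on:
        wrapper.cache_clear = store.clear
        wrapper.cache_delete = lambda *a, **kw: store.pop(_key(a, kw), None)
        wrapper.cache_info = lambda: {
            "size": len(store),
            "maxsize": maxsize,
            "ttl_seconds": ttl_seconds,
        }
        return wrapper

    return decorator


calls = 0


@cached(maxsize=10, ttl_seconds=60)
async def fetch(user_id: str, page: int = 1):
    global calls
    calls += 1
    return f"{user_id}:{page}"


async def main():
    await fetch("u1", 1)
    await fetch("u1", 1)  # served from cache, no second call
    assert calls == 1
    fetch.cache_delete("u1", 1)  # targeted invalidation, like the routes do
    await fetch("u1", 1)
    assert calls == 2


asyncio.run(main())
```

Note that with this naive key scheme, `cache_delete("u1", 1)` and `cache_delete(user_id="u1", page=1)` address different keys; whatever the real decorator does about that is not visible in this diff.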
autogpt_platform/backend/backend/server/v2/library/cache_test.py (new file, +286 lines)
@@ -0,0 +1,286 @@
"""
Tests for cache invalidation in Library API routes.

This module tests that library caches are properly invalidated when data is modified.
"""

import uuid
from unittest.mock import AsyncMock, patch

import pytest

import backend.server.v2.library.cache as library_cache
import backend.server.v2.library.db as library_db


@pytest.fixture
def mock_user_id():
    """Generate a mock user ID for testing."""
    return str(uuid.uuid4())


@pytest.fixture
def mock_library_agent_id():
    """Generate a mock library agent ID for testing."""
    return str(uuid.uuid4())


class TestLibraryAgentCacheInvalidation:
    """Test cache invalidation for library agent operations."""

    @pytest.mark.asyncio
    async def test_add_agent_clears_list_cache(self, mock_user_id):
        """Test that adding an agent clears the library agents list cache."""
        # Clear cache
        library_cache.get_cached_library_agents.cache_clear()

        with patch.object(
            library_db, "list_library_agents", new_callable=AsyncMock
        ) as mock_list:
            mock_response = {"agents": [], "total_count": 0, "page": 1, "page_size": 20}
            mock_list.return_value = mock_response

            # First call hits database
            await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
            assert mock_list.call_count == 1

            # Second call uses cache
            await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
            assert mock_list.call_count == 1  # Still 1, cache used

            # Simulate adding an agent (cache invalidation)
            for page in range(1, 5):
                library_cache.get_cached_library_agents.cache_delete(
                    mock_user_id, page, 15
                )
                library_cache.get_cached_library_agents.cache_delete(
                    mock_user_id, page, 20
                )

            # Next call should hit database
            await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
            assert mock_list.call_count == 2

    @pytest.mark.asyncio
    async def test_delete_agent_clears_multiple_caches(
        self, mock_user_id, mock_library_agent_id
    ):
        """Test that deleting an agent clears both specific and list caches."""
        # Clear caches
        library_cache.get_cached_library_agent.cache_clear()
        library_cache.get_cached_library_agents.cache_clear()

        with (
            patch.object(
                library_db, "get_library_agent", new_callable=AsyncMock
            ) as mock_get,
            patch.object(
                library_db, "list_library_agents", new_callable=AsyncMock
            ) as mock_list,
        ):
            mock_agent = {"id": mock_library_agent_id, "name": "Test Agent"}
            mock_get.return_value = mock_agent
            mock_list.return_value = {
                "agents": [mock_agent],
                "total_count": 1,
                "page": 1,
                "page_size": 20,
            }

            # Populate caches
            await library_cache.get_cached_library_agent(
                mock_library_agent_id, mock_user_id
            )
            await library_cache.get_cached_library_agents(mock_user_id, 1, 20)

            initial_calls = {
                "get": mock_get.call_count,
                "list": mock_list.call_count,
            }

            # Verify cache is used
            await library_cache.get_cached_library_agent(
                mock_library_agent_id, mock_user_id
            )
            await library_cache.get_cached_library_agents(mock_user_id, 1, 20)

            assert mock_get.call_count == initial_calls["get"]
            assert mock_list.call_count == initial_calls["list"]

            # Simulate delete_library_agent cache invalidation
            library_cache.get_cached_library_agent.cache_delete(
                mock_library_agent_id, mock_user_id
            )
            for page in range(1, 5):
                library_cache.get_cached_library_agents.cache_delete(
                    mock_user_id, page, 15
                )
                library_cache.get_cached_library_agents.cache_delete(
                    mock_user_id, page, 20
                )

            # Next calls should hit database
            await library_cache.get_cached_library_agent(
                mock_library_agent_id, mock_user_id
            )
            await library_cache.get_cached_library_agents(mock_user_id, 1, 20)

            assert mock_get.call_count == initial_calls["get"] + 1
            assert mock_list.call_count == initial_calls["list"] + 1

    @pytest.mark.asyncio
    async def test_favorites_cache_operations(self, mock_user_id):
        """Test that favorites cache works independently."""
        # Clear cache
        library_cache.get_cached_library_agent_favorites.cache_clear()

        with patch.object(
            library_db, "list_favorite_library_agents", new_callable=AsyncMock
        ) as mock_favs:
            mock_response = {"agents": [], "total_count": 0, "page": 1, "page_size": 20}
            mock_favs.return_value = mock_response

            # First call hits database
            await library_cache.get_cached_library_agent_favorites(mock_user_id, 1, 20)
            assert mock_favs.call_count == 1

            # Second call uses cache
            await library_cache.get_cached_library_agent_favorites(mock_user_id, 1, 20)
            assert mock_favs.call_count == 1  # Cache used

            # Clear cache
            library_cache.get_cached_library_agent_favorites.cache_delete(
                mock_user_id, 1, 20
            )

            # Next call hits database
            await library_cache.get_cached_library_agent_favorites(mock_user_id, 1, 20)
            assert mock_favs.call_count == 2


class TestLibraryPresetCacheInvalidation:
    """Test cache invalidation for library preset operations."""

    @pytest.mark.asyncio
    async def test_preset_cache_operations(self, mock_user_id):
        """Test preset cache and invalidation."""
        # Clear cache
        library_cache.get_cached_library_presets.cache_clear()
        library_cache.get_cached_library_preset.cache_clear()

        preset_id = str(uuid.uuid4())

        with (
            patch.object(
                library_db, "list_presets", new_callable=AsyncMock
            ) as mock_list,
            patch.object(library_db, "get_preset", new_callable=AsyncMock) as mock_get,
        ):
            mock_preset = {"id": preset_id, "name": "Test Preset"}
            mock_list.return_value = {
                "presets": [mock_preset],
                "total_count": 1,
                "page": 1,
                "page_size": 20,
            }
            mock_get.return_value = mock_preset

            # Populate caches
            await library_cache.get_cached_library_presets(mock_user_id, 1, 20)
            await library_cache.get_cached_library_preset(preset_id, mock_user_id)

            initial_calls = {
                "list": mock_list.call_count,
                "get": mock_get.call_count,
            }

            # Verify cache is used
            await library_cache.get_cached_library_presets(mock_user_id, 1, 20)
            await library_cache.get_cached_library_preset(preset_id, mock_user_id)

            assert mock_list.call_count == initial_calls["list"]
            assert mock_get.call_count == initial_calls["get"]

            # Clear specific preset cache
            library_cache.get_cached_library_preset.cache_delete(
                preset_id, mock_user_id
            )

            # Clear list cache
            library_cache.get_cached_library_presets.cache_delete(mock_user_id, 1, 20)

            # Next calls should hit database
            await library_cache.get_cached_library_presets(mock_user_id, 1, 20)
            await library_cache.get_cached_library_preset(preset_id, mock_user_id)

            assert mock_list.call_count == initial_calls["list"] + 1
            assert mock_get.call_count == initial_calls["get"] + 1


class TestLibraryCacheMetrics:
    """Test library cache metrics and management."""

    def test_cache_info_structure(self):
        """Test that cache_info returns expected structure."""
        info = library_cache.get_cached_library_agents.cache_info()

        assert "size" in info
        assert "maxsize" in info
        assert "ttl_seconds" in info
        assert (
            info["maxsize"] is None
        )  # Redis manages its own size with shared_cache=True
        assert info["ttl_seconds"] == 600  # 10 minutes

    def test_all_library_caches_can_be_cleared(self):
        """Test that all library caches can be cleared."""
        # Clear all library caches
        library_cache.get_cached_library_agents.cache_clear()
        library_cache.get_cached_library_agent_favorites.cache_clear()
        library_cache.get_cached_library_agent.cache_clear()
        library_cache.get_cached_library_agent_by_graph_id.cache_clear()
        library_cache.get_cached_library_agent_by_store_version.cache_clear()
        library_cache.get_cached_library_presets.cache_clear()
        library_cache.get_cached_library_preset.cache_clear()

        # Verify all are empty
        assert library_cache.get_cached_library_agents.cache_info()["size"] == 0
        assert (
            library_cache.get_cached_library_agent_favorites.cache_info()["size"] == 0
        )
        assert library_cache.get_cached_library_agent.cache_info()["size"] == 0
        assert (
            library_cache.get_cached_library_agent_by_graph_id.cache_info()["size"] == 0
        )
        assert (
            library_cache.get_cached_library_agent_by_store_version.cache_info()["size"]
            == 0
        )
        assert library_cache.get_cached_library_presets.cache_info()["size"] == 0
        assert library_cache.get_cached_library_preset.cache_info()["size"] == 0

    def test_cache_ttl_values(self):
        """Test that cache TTL values are set correctly."""
        # Library agents - 10 minutes
        assert (
            library_cache.get_cached_library_agents.cache_info()["ttl_seconds"] == 600
        )

        # Favorites - 5 minutes (more dynamic)
        assert (
            library_cache.get_cached_library_agent_favorites.cache_info()["ttl_seconds"]
            == 300
        )

        # Individual agent - 30 minutes
        assert (
            library_cache.get_cached_library_agent.cache_info()["ttl_seconds"] == 1800
        )

        # Presets - 30 minutes
        assert (
            library_cache.get_cached_library_presets.cache_info()["ttl_seconds"] == 1800
        )
        assert (
            library_cache.get_cached_library_preset.cache_info()["ttl_seconds"] == 1800
        )
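Note that the tests above pass cache keys positionally (`cache_delete(mock_user_id, page, 15)`) while the route code passes keywords (`cache_delete(user_id=..., page=..., page_size=...)`); whether both address the same entry depends on how `backend.util.cache` derives its keys, which this diff does not show. One way to make the two call styles equivalent is to bind the arguments against the wrapped function's signature before building the key — a sketch:

```python
import inspect


def normalized_key(fn, args, kwargs):
    """Bind positional and keyword call styles to one canonical cache key."""
    bound = inspect.signature(fn).bind(*args, **kwargs)
    bound.apply_defaults()  # fill in defaults so partial calls also match
    return tuple(sorted(bound.arguments.items()))


# Signature mirrors the cached helper in the diff above.
async def get_cached_library_agents(user_id: str, page: int = 1, page_size: int = 20):
    ...


k_positional = normalized_key(get_cached_library_agents, ("u1", 2, 20), {})
k_keyword = normalized_key(
    get_cached_library_agents, (), {"user_id": "u1", "page": 2, "page_size": 20}
)
assert k_positional == k_keyword  # same cache entry regardless of call style
```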
@@ -20,7 +20,7 @@ from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
 from backend.data.model import CredentialsMetaInput
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
-from backend.util.exceptions import DatabaseError, NotFoundError
+from backend.util.exceptions import NotFoundError
 from backend.util.json import SafeJson
 from backend.util.models import Pagination
 from backend.util.settings import Config
@@ -61,11 +61,11 @@ async def list_library_agents(

     if page < 1 or page_size < 1:
         logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
-        raise DatabaseError("Invalid pagination input")
+        raise store_exceptions.DatabaseError("Invalid pagination input")

     if search_term and len(search_term.strip()) > 100:
         logger.warning(f"Search term too long: {repr(search_term)}")
-        raise DatabaseError("Search term is too long")
+        raise store_exceptions.DatabaseError("Search term is too long")

     where_clause: prisma.types.LibraryAgentWhereInput = {
         "userId": user_id,
@@ -143,7 +143,7 @@ async def list_library_agents(

     except prisma.errors.PrismaError as e:
         logger.error(f"Database error fetching library agents: {e}")
-        raise DatabaseError("Failed to fetch library agents") from e
+        raise store_exceptions.DatabaseError("Failed to fetch library agents") from e


 async def list_favorite_library_agents(
@@ -172,7 +172,7 @@ async def list_favorite_library_agents(

     if page < 1 or page_size < 1:
         logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
-        raise DatabaseError("Invalid pagination input")
+        raise store_exceptions.DatabaseError("Invalid pagination input")

     where_clause: prisma.types.LibraryAgentWhereInput = {
         "userId": user_id,
@@ -229,7 +229,9 @@ async def list_favorite_library_agents(

     except prisma.errors.PrismaError as e:
         logger.error(f"Database error fetching favorite library agents: {e}")
-        raise DatabaseError("Failed to fetch favorite library agents") from e
+        raise store_exceptions.DatabaseError(
+            "Failed to fetch favorite library agents"
+        ) from e


 async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent:
@@ -271,7 +273,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent

     except prisma.errors.PrismaError as e:
         logger.error(f"Database error fetching library agent: {e}")
-        raise DatabaseError("Failed to fetch library agent") from e
+        raise store_exceptions.DatabaseError("Failed to fetch library agent") from e


 async def get_library_agent_by_store_version_id(
@@ -336,7 +338,7 @@ async def get_library_agent_by_graph_id(
         return library_model.LibraryAgent.from_db(agent, sub_graphs=sub_graphs)
     except prisma.errors.PrismaError as e:
         logger.error(f"Database error fetching library agent by graph ID: {e}")
-        raise DatabaseError("Failed to fetch library agent") from e
+        raise store_exceptions.DatabaseError("Failed to fetch library agent") from e


 async def add_generated_agent_image(
@@ -477,7 +479,9 @@ async def update_agent_version_in_library(
         )
     except prisma.errors.PrismaError as e:
         logger.error(f"Database error updating agent version in library: {e}")
-        raise DatabaseError("Failed to update agent version in library") from e
+        raise store_exceptions.DatabaseError(
+            "Failed to update agent version in library"
+        ) from e


 async def update_library_agent(
@@ -540,7 +544,7 @@ async def update_library_agent(
         )
     except prisma.errors.PrismaError as e:
         logger.error(f"Database error updating library agent: {str(e)}")
-        raise DatabaseError("Failed to update library agent") from e
+        raise store_exceptions.DatabaseError("Failed to update library agent") from e


 async def delete_library_agent(
@@ -568,7 +572,7 @@ async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
         )
     except prisma.errors.PrismaError as e:
         logger.error(f"Database error deleting library agent: {e}")
-        raise DatabaseError("Failed to delete library agent") from e
+        raise store_exceptions.DatabaseError("Failed to delete library agent") from e


 async def add_store_agent_to_library(
@@ -659,7 +663,7 @@ async def add_store_agent_to_library(
         raise
     except prisma.errors.PrismaError as e:
         logger.error(f"Database error adding agent to library: {e}")
-        raise DatabaseError("Failed to add agent to library") from e
+        raise store_exceptions.DatabaseError("Failed to add agent to library") from e


 ##############################################
@@ -693,7 +697,7 @@ async def list_presets(
         logger.warning(
             "Invalid pagination input: page=%d, page_size=%d", page, page_size
         )
-        raise DatabaseError("Invalid pagination parameters")
+        raise store_exceptions.DatabaseError("Invalid pagination parameters")

     query_filter: prisma.types.AgentPresetWhereInput = {
         "userId": user_id,
@@ -729,7 +733,7 @@ async def list_presets(

     except prisma.errors.PrismaError as e:
         logger.error(f"Database error getting presets: {e}")
-        raise DatabaseError("Failed to fetch presets") from e
+        raise store_exceptions.DatabaseError("Failed to fetch presets") from e


 async def get_preset(
@@ -759,7 +763,7 @@ async def get_preset(
         return library_model.LibraryAgentPreset.from_db(preset)
     except prisma.errors.PrismaError as e:
         logger.error(f"Database error getting preset: {e}")
-        raise DatabaseError("Failed to fetch preset") from e
+        raise store_exceptions.DatabaseError("Failed to fetch preset") from e


 async def create_preset(
@@ -809,7 +813,7 @@ async def create_preset(
         return library_model.LibraryAgentPreset.from_db(new_preset)
     except prisma.errors.PrismaError as e:
         logger.error(f"Database error creating preset: {e}")
-        raise DatabaseError("Failed to create preset") from e
+        raise store_exceptions.DatabaseError("Failed to create preset") from e


 async def create_preset_from_graph_execution(
@@ -947,7 +951,7 @@ async def update_preset(
         return library_model.LibraryAgentPreset.from_db(updated)
     except prisma.errors.PrismaError as e:
         logger.error(f"Database error updating preset: {e}")
-        raise DatabaseError("Failed to update preset") from e
+        raise store_exceptions.DatabaseError("Failed to update preset") from e


 async def set_preset_webhook(
@@ -993,7 +997,7 @@ async def delete_preset(user_id: str, preset_id: str) -> None:
         )
     except prisma.errors.PrismaError as e:
         logger.error(f"Database error deleting preset: {e}")
-        raise DatabaseError("Failed to delete preset") from e
+        raise store_exceptions.DatabaseError("Failed to delete preset") from e


 async def fork_library_agent(
@@ -1021,7 +1025,7 @@ async def fork_library_agent(
     # TODO: once we have open/closed sourced agents this needs to be enabled ~kcze
     # + update library/agents/[id]/page.tsx agent actions
     # if not original_agent.can_access_graph:
-    #     raise DatabaseError(
+    #     raise store_exceptions.DatabaseError(
     #         f"User {user_id} cannot access library agent graph {library_agent_id}"
     #     )

@@ -1035,4 +1039,4 @@ async def fork_library_agent(
         return (await create_library_agent(new_graph, user_id))[0]
     except prisma.errors.PrismaError as e:
         logger.error(f"Database error cloning library agent: {e}")
-        raise DatabaseError("Failed to fork library agent") from e
+        raise store_exceptions.DatabaseError("Failed to fork library agent") from e
@@ -5,10 +5,12 @@ import autogpt_libs.auth as autogpt_auth_lib
 from fastapi import APIRouter, Body, HTTPException, Query, Security, status
 from fastapi.responses import Response

+import backend.server.cache_config
+import backend.server.v2.library.cache as library_cache
 import backend.server.v2.library.db as library_db
 import backend.server.v2.library.model as library_model
 import backend.server.v2.store.exceptions as store_exceptions
-from backend.util.exceptions import DatabaseError, NotFoundError
+from backend.util.exceptions import NotFoundError

 logger = logging.getLogger(__name__)

@@ -64,13 +66,22 @@ async def list_library_agents(
         HTTPException: If a server/database error occurs.
     """
     try:
-        return await library_db.list_library_agents(
-            user_id=user_id,
-            search_term=search_term,
-            sort_by=sort_by,
-            page=page,
-            page_size=page_size,
-        )
+        # Use cache for default queries (no search term, default sort)
+        if search_term is None and sort_by == library_model.LibraryAgentSort.UPDATED_AT:
+            return await library_cache.get_cached_library_agents(
+                user_id=user_id,
+                page=page,
+                page_size=page_size,
+            )
+        else:
+            # Direct DB query for searches and custom sorts
+            return await library_db.list_library_agents(
+                user_id=user_id,
+                search_term=search_term,
+                sort_by=sort_by,
+                page=page,
+                page_size=page_size,
+            )
     except Exception as e:
         logger.error(f"Could not list library agents for user #{user_id}: {e}")
         raise HTTPException(
@@ -114,7 +125,7 @@ async def list_favorite_library_agents(
         HTTPException: If a server/database error occurs.
     """
     try:
-        return await library_db.list_favorite_library_agents(
+        return await library_cache.get_cached_library_agent_favorites(
             user_id=user_id,
             page=page,
             page_size=page_size,
@@ -132,7 +143,9 @@ async def get_library_agent(
     library_agent_id: str,
     user_id: str = Security(autogpt_auth_lib.get_user_id),
 ) -> library_model.LibraryAgent:
-    return await library_db.get_library_agent(id=library_agent_id, user_id=user_id)
+    return await library_cache.get_cached_library_agent(
+        library_agent_id=library_agent_id, user_id=user_id
+    )


 @router.get("/by-graph/{graph_id}")
@@ -210,18 +223,28 @@ async def add_marketplace_agent_to_library(
         HTTPException(500): If a server/database error occurs.
     """
     try:
-        return await library_db.add_store_agent_to_library(
+        result = await library_db.add_store_agent_to_library(
             store_listing_version_id=store_listing_version_id,
             user_id=user_id,
         )

+        # Clear library caches after adding new agent
+        for page in range(1, backend.server.cache_config.MAX_PAGES_TO_CLEAR):
+            library_cache.get_cached_library_agents.cache_delete(
+                user_id=user_id,
+                page=page,
+                page_size=backend.server.cache_config.V2_LIBRARY_AGENTS_PAGE_SIZE,
+            )
+
+        return result
+
     except store_exceptions.AgentNotFoundError as e:
         logger.warning(
             f"Could not find store listing version {store_listing_version_id} "
             "to add to library"
         )
         raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
-    except DatabaseError as e:
+    except store_exceptions.DatabaseError as e:
         logger.error(f"Database error while adding agent to library: {e}", e)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -263,19 +286,28 @@ async def update_library_agent(
         HTTPException(500): If a server/database error occurs.
     """
     try:
-        return await library_db.update_library_agent(
+        result = await library_db.update_library_agent(
             library_agent_id=library_agent_id,
             user_id=user_id,
             auto_update_version=payload.auto_update_version,
             is_favorite=payload.is_favorite,
             is_archived=payload.is_archived,
         )
+
+        for page in range(1, backend.server.cache_config.MAX_PAGES_TO_CLEAR):
+            library_cache.get_cached_library_agent_favorites.cache_delete(
+                user_id=user_id,
+                page=page,
+                page_size=backend.server.cache_config.V2_LIBRARY_AGENTS_PAGE_SIZE,
+            )
+
+        return result
     except NotFoundError as e:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail=str(e),
         ) from e
-    except DatabaseError as e:
+    except store_exceptions.DatabaseError as e:
         logger.error(f"Database error while updating library agent: {e}")
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -320,6 +352,18 @@ async def delete_library_agent(
         await library_db.delete_library_agent(
             library_agent_id=library_agent_id, user_id=user_id
         )
+
+        # Clear caches after deleting agent
+        library_cache.get_cached_library_agent.cache_delete(
+            library_agent_id=library_agent_id, user_id=user_id
+        )
+        for page in range(1, backend.server.cache_config.MAX_PAGES_TO_CLEAR):
+            library_cache.get_cached_library_agents.cache_delete(
+                user_id=user_id,
+                page=page,
+                page_size=backend.server.cache_config.V2_LIBRARY_AGENTS_PAGE_SIZE,
+            )
+
         return Response(status_code=status.HTTP_204_NO_CONTENT)
     except NotFoundError as e:
         raise HTTPException(
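The invalidation loops above delete list-cache entries page by page up to `MAX_PAGES_TO_CLEAR`, which can miss entries cached under a page beyond that bound or a page size outside the cleared set. A common alternative (a sketch, not what this PR implements) is generation-based invalidation: a per-user version number is part of every key, and bumping it orphans all cached pages for that user at once.

```python
import asyncio


class VersionedListCache:
    """Generation-based invalidation for paginated list caches (sketch).

    Bumping a user's version changes the key of every subsequent lookup,
    so stale pages are simply never addressed again and expire naturally.
    """

    def __init__(self):
        self._versions: dict[str, int] = {}
        self._store: dict[tuple, object] = {}

    def _version(self, user_id: str) -> int:
        return self._versions.get(user_id, 0)

    async def get(self, user_id: str, page: int, page_size: int, loader):
        key = (user_id, self._version(user_id), page, page_size)
        if key not in self._store:
            self._store[key] = await loader(user_id, page, page_size)
        return self._store[key]

    def invalidate_user(self, user_id: str) -> None:
        # Replaces the page-by-page cache_delete loop in one O(1) step.
        self._versions[user_id] = self._version(user_id) + 1


CALLS = 0


async def load_page(user_id: str, page: int, page_size: int):
    global CALLS
    CALLS += 1
    return [f"{user_id}-p{page}"]


cache = VersionedListCache()


async def main():
    await cache.get("u1", 1, 20, load_page)
    await cache.get("u1", 1, 20, load_page)  # cached, loader not called again
    assert CALLS == 1
    cache.invalidate_user("u1")  # every page/page_size combo is now stale
    await cache.get("u1", 1, 20, load_page)
    assert CALLS == 2


asyncio.run(main())
```

The trade-off is that orphaned entries linger until their TTL expires, so this suits TTL-bounded stores like the ones in this PR.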
@@ -4,6 +4,9 @@ from typing import Any, Optional
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, Body, HTTPException, Query, Security, status
|
||||
|
||||
import backend.server.cache_config
|
||||
import backend.server.routers.cache as cache
|
||||
import backend.server.v2.library.cache as library_cache
|
||||
import backend.server.v2.library.db as db
|
||||
import backend.server.v2.library.model as models
|
||||
from backend.data.execution import GraphExecutionMeta
|
||||
@@ -25,6 +28,24 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
def _clear_presets_list_cache(
|
||||
user_id: str, num_pages: int = backend.server.cache_config.MAX_PAGES_TO_CLEAR
|
||||
):
|
||||
"""
|
||||
Clear the presets list cache for the given user.
|
||||
Clears both primary and alternative page sizes for backward compatibility.
|
||||
"""
|
||||
page_sizes = backend.server.cache_config.get_page_sizes_for_clearing(
|
||||
backend.server.cache_config.V2_LIBRARY_PRESETS_PAGE_SIZE,
|
||||
backend.server.cache_config.V2_LIBRARY_PRESETS_ALT_PAGE_SIZE,
|
||||
)
|
||||
for page in range(1, num_pages + 1):
|
||||
for page_size in page_sizes:
|
||||
library_cache.get_cached_library_presets.cache_delete(
|
||||
user_id=user_id, page=page, page_size=page_size
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/presets",
|
||||
summary="List presets",
|
||||
@@ -51,12 +72,21 @@ async def list_presets(
|
||||
models.LibraryAgentPresetResponse: A response containing the list of presets.
|
||||
"""
|
||||
try:
|
||||
return await db.list_presets(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
# Use cache only for default queries (no filter)
|
||||
if graph_id is None:
|
||||
return await library_cache.get_cached_library_presets(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
else:
|
||||
# Direct DB query for filtered requests
|
||||
return await db.list_presets(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to list presets for user %s: %s", user_id, e)
|
||||
raise HTTPException(
|
||||
@@ -87,7 +117,7 @@ async def get_preset(
|
||||
HTTPException: If the preset is not found or an error occurs.
"""
try:
preset = await db.get_preset(user_id, preset_id)
preset = await library_cache.get_cached_library_preset(preset_id, user_id)
except Exception as e:
logger.exception(
"Error retrieving preset %s for user %s: %s", preset_id, user_id, e
@@ -131,9 +161,13 @@ async def create_preset(
"""
try:
if isinstance(preset, models.LibraryAgentPresetCreatable):
return await db.create_preset(user_id, preset)
result = await db.create_preset(user_id, preset)
else:
return await db.create_preset_from_graph_execution(user_id, preset)
result = await db.create_preset_from_graph_execution(user_id, preset)

_clear_presets_list_cache(user_id)

return result
except NotFoundError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except Exception as e:
@@ -200,6 +234,9 @@ async def setup_trigger(
is_active=True,
),
)

_clear_presets_list_cache(user_id)

return new_preset


@@ -278,6 +315,13 @@ async def update_preset(
description=preset.description,
is_active=preset.is_active,
)

# Clear caches after updating preset
library_cache.get_cached_library_preset.cache_delete(
preset_id=preset_id, user_id=user_id
)
_clear_presets_list_cache(user_id)

except Exception as e:
logger.exception("Preset update failed for user %s: %s", user_id, e)
raise HTTPException(
@@ -351,6 +395,12 @@ async def delete_preset(

try:
await db.delete_preset(user_id, preset_id)

# Clear caches after deleting preset
library_cache.get_cached_library_preset.cache_delete(
preset_id=preset_id, user_id=user_id
)
_clear_presets_list_cache(user_id)
except Exception as e:
logger.exception(
"Error deleting preset %s for user %s: %s", preset_id, user_id, e
@@ -401,6 +451,33 @@ async def execute_preset(
merged_node_input = preset.inputs | inputs
merged_credential_inputs = preset.credentials | credential_inputs

# Clear graph executions cache - use both page sizes for compatibility
for page in range(1, 10):
# Clear with alternative page size
cache.get_cached_graph_executions.cache_delete(
graph_id=preset.graph_id,
user_id=user_id,
page=page,
page_size=backend.server.cache_config.V2_GRAPH_EXECUTIONS_ALT_PAGE_SIZE,
)
cache.get_cached_graph_executions.cache_delete(
user_id=user_id,
page=page,
page_size=backend.server.cache_config.V2_GRAPH_EXECUTIONS_ALT_PAGE_SIZE,
)
# Clear with v1 page size (25)
cache.get_cached_graph_executions.cache_delete(
graph_id=preset.graph_id,
user_id=user_id,
page=page,
page_size=backend.server.cache_config.V1_GRAPH_EXECUTIONS_PAGE_SIZE,
)
cache.get_cached_graph_executions.cache_delete(
user_id=user_id,
page=page,
page_size=backend.server.cache_config.V1_GRAPH_EXECUTIONS_PAGE_SIZE,
)

return await add_graph_execution(
user_id=user_id,
graph_id=preset.graph_id,

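The execute_preset hunk above invalidates cache entries by looping over page numbers and both configured page sizes. A minimal sketch of why every exact keyword combination must be deleted — the `cached` decorator and `cache_delete` below are hypothetical stand-ins for `backend.util.cache`, not the repo's actual implementation:

```python
import functools

CALLS = {"count": 0}  # counts real computations, to show cache hits vs misses


def cached(func):
    """Toy keyword-keyed cache; illustrative only."""
    store = {}

    @functools.wraps(func)
    def wrapper(**kwargs):
        key = tuple(sorted(kwargs.items()))
        if key not in store:
            store[key] = func(**kwargs)
        return store[key]

    def cache_delete(**kwargs):
        # Deletion must use the exact key the entry was stored under, which
        # is why callers enumerate every (page, page_size) combination.
        store.pop(tuple(sorted(kwargs.items())), None)

    wrapper.cache_delete = cache_delete
    return wrapper


@cached
def get_executions(user_id, page, page_size):
    CALLS["count"] += 1
    return f"executions:{user_id}:p{page}:s{page_size}"
```

An entry cached under `page_size=25` is untouched by a delete with `page_size=20`, so the route clears both sizes explicitly.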
@@ -179,14 +179,15 @@ async def test_get_favorite_library_agents_success(
def test_get_favorite_library_agents_error(
mocker: pytest_mock.MockFixture, test_user_id: str
):
mock_db_call = mocker.patch(
"backend.server.v2.library.db.list_favorite_library_agents"
# Mock the cache function instead of the DB directly since routes now use cache
mock_cache_call = mocker.patch(
"backend.server.v2.library.routes.agents.library_cache.get_cached_library_agent_favorites"
)
mock_db_call.side_effect = Exception("Test error")
mock_cache_call.side_effect = Exception("Test error")

response = client.get("/agents/favorites")
assert response.status_code == 500
mock_db_call.assert_called_once_with(
mock_cache_call.assert_called_once_with(
user_id=test_user_id,
page=1,
page_size=15,

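The test change above swaps the patch target from the DB function to the cache helper. A small sketch of the underlying rule — patch the name where the caller looks it up; module and function names below are illustrative, not the repo's:

```python
import sys
import unittest.mock as mock


def db_fetch():
    return "from-db"


def cached_fetch():
    # stand-in for a cache helper such as get_cached_library_agent_favorites
    return db_fetch()


def route():
    # the route calls the cache helper, so tests must patch *this* name;
    # a patch on db_fetch alone is never seen once a cache layer sits between
    return cached_fetch()


def run_with_cache_error():
    """Simulate the test: make the cache helper raise and observe the call."""
    with mock.patch.object(
        sys.modules[__name__], "cached_fetch", side_effect=Exception("Test error")
    ) as m:
        try:
            route()
        except Exception as e:
            return str(e), m.call_count
    return None
```

This is the standard "where to patch" rule from `unittest.mock`: the patched attribute must live in the namespace the code under test reads from.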
@@ -1,22 +1,61 @@
"""
Cache functions for Store API endpoints.

This module contains all caching decorators and helpers for the Store API,
separated from the main routes for better organization and maintainability.
"""

import backend.server.cache_config
import backend.server.v2.store.db
from backend.util.cache import cached

##############################################
############### Caches #######################
##############################################

def _clear_submissions_cache(
user_id: str, num_pages: int = backend.server.cache_config.MAX_PAGES_TO_CLEAR
):
"""
Clear the submissions cache for the given user.

Args:
user_id: User ID whose cache should be cleared
num_pages: Number of pages to clear (default from cache_config)
"""
for page in range(1, num_pages + 1):
_get_cached_submissions.cache_delete(
user_id=user_id,
page=page,
page_size=backend.server.cache_config.V2_STORE_SUBMISSIONS_PAGE_SIZE,
)


def clear_all_caches():
"""Clear all caches."""
_get_cached_store_agents.cache_clear()
_get_cached_agent_details.cache_clear()
_get_cached_store_creators.cache_clear()
_get_cached_creator_details.cache_clear()
def _clear_my_agents_cache(
user_id: str, num_pages: int = backend.server.cache_config.MAX_PAGES_TO_CLEAR
):
"""
Clear the my agents cache for the given user.

Args:
user_id: User ID whose cache should be cleared
num_pages: Number of pages to clear (default from cache_config)
"""
for page in range(1, num_pages + 1):
_get_cached_my_agents.cache_delete(
user_id=user_id,
page=page,
page_size=backend.server.cache_config.V2_MY_AGENTS_PAGE_SIZE,
)


# Cache store agents list for 5 minutes
# Cache user profiles for 1 hour per user
@cached(maxsize=1000, ttl_seconds=3600, shared_cache=True)
async def _get_cached_user_profile(user_id: str):
"""Cached helper to get user profile."""
return await backend.server.v2.store.db.get_user_profile(user_id)


# Cache store agents list for 15 minutes
# Different cache entries for different query combinations
@cached(maxsize=5000, ttl_seconds=300, shared_cache=True)
@cached(maxsize=5000, ttl_seconds=900, shared_cache=True)
async def _get_cached_store_agents(
featured: bool,
creator: str | None,
@@ -39,7 +78,7 @@ async def _get_cached_store_agents(


# Cache individual agent details for 15 minutes
@cached(maxsize=200, ttl_seconds=300, shared_cache=True)
@cached(maxsize=200, ttl_seconds=900, shared_cache=True)
async def _get_cached_agent_details(username: str, agent_name: str):
"""Cached helper to get agent details."""
return await backend.server.v2.store.db.get_store_agent_details(
@@ -47,8 +86,26 @@ async def _get_cached_agent_details(username: str, agent_name: str):
)


# Cache creators list for 5 minutes
@cached(maxsize=200, ttl_seconds=300, shared_cache=True)
# Cache agent graphs for 1 hour
@cached(maxsize=200, ttl_seconds=3600, shared_cache=True)
async def _get_cached_agent_graph(store_listing_version_id: str):
"""Cached helper to get agent graph."""
return await backend.server.v2.store.db.get_available_graph(
store_listing_version_id
)


# Cache agent by version for 1 hour
@cached(maxsize=200, ttl_seconds=3600, shared_cache=True)
async def _get_cached_store_agent_by_version(store_listing_version_id: str):
"""Cached helper to get store agent by version ID."""
return await backend.server.v2.store.db.get_store_agent_by_version_id(
store_listing_version_id
)


# Cache creators list for 1 hour
@cached(maxsize=200, ttl_seconds=3600, shared_cache=True)
async def _get_cached_store_creators(
featured: bool,
search_query: str | None,
@@ -66,10 +123,30 @@ async def _get_cached_store_creators(
)


# Cache individual creator details for 5 minutes
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
# Cache individual creator details for 1 hour
@cached(maxsize=100, ttl_seconds=3600, shared_cache=True)
async def _get_cached_creator_details(username: str):
"""Cached helper to get creator details."""
return await backend.server.v2.store.db.get_store_creator_details(
username=username.lower()
)


# Cache user's own agents for 5 mins (shorter TTL as this changes more frequently)
@cached(maxsize=500, ttl_seconds=300, shared_cache=True)
async def _get_cached_my_agents(user_id: str, page: int, page_size: int):
"""Cached helper to get user's agents."""
return await backend.server.v2.store.db.get_my_agents(
user_id, page=page, page_size=page_size
)


# Cache user's submissions for 1 hour
@cached(maxsize=500, ttl_seconds=3600, shared_cache=True)
async def _get_cached_submissions(user_id: str, page: int, page_size: int):
"""Cached helper to get user's submissions."""
return await backend.server.v2.store.db.get_store_submissions(
user_id=user_id,
page=page,
page_size=page_size,
)

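The decorators in this module combine a size bound (`maxsize`) with time-based expiry (`ttl_seconds`). A toy model of the TTL semantics assumed here — not the actual `backend.util.cache` internals:

```python
import time


class TTLCache:
    """Toy TTL cache illustrating ttl_seconds-style expiry."""

    def __init__(self, ttl_seconds: float):
        self.ttl = ttl_seconds
        self.data = {}  # key -> (value, stored_at)

    def get(self, key):
        entry = self.data.get(key)
        if entry is None:
            return None
        value, stored_at = entry
        if time.monotonic() - stored_at > self.ttl:
            del self.data[key]  # expired entries are dropped lazily on read
            return None
        return value

    def set(self, key, value):
        self.data[key] = (value, time.monotonic())
```

Lazy expiry on read keeps writes cheap; a real shared cache would also need an eviction policy for `maxsize` and cross-process invalidation, which this sketch omits.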
@@ -1,6 +1,5 @@
import asyncio
import logging
import typing
from datetime import datetime, timezone

import fastapi
@@ -26,7 +25,6 @@ from backend.data.notifications (
NotificationEventModel,
)
from backend.notifications.notifications import queue_notification_async
from backend.util.exceptions import DatabaseError
from backend.util.settings import Settings

logger = logging.getLogger(__name__)
@@ -72,199 +70,65 @@ async def get_store_agents(
logger.debug(
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
)
sanitized_query = sanitize_query(search_query)

sanitized_creators = []
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
if featured:
where_clause["featured"] = featured
if creators:
for c in creators:
sanitized_creators.append(sanitize_query(c))

sanitized_category = None
where_clause["creator_username"] = {"in": creators}
if category:
sanitized_category = sanitize_query(category)
where_clause["categories"] = {"has": category}

if sanitized_query:
where_clause["OR"] = [
{"agent_name": {"contains": sanitized_query, "mode": "insensitive"}},
{"description": {"contains": sanitized_query, "mode": "insensitive"}},
]

order_by = []
if sorted_by == "rating":
order_by.append({"rating": "desc"})
elif sorted_by == "runs":
order_by.append({"runs": "desc"})
elif sorted_by == "name":
order_by.append({"agent_name": "asc"})

try:
# If search_query is provided, use full-text search
if search_query:
search_term = sanitize_query(search_query)
if not search_term:
# Return empty results for invalid search query
return backend.server.v2.store.model.StoreAgentsResponse(
agents=[],
pagination=backend.server.v2.store.model.Pagination(
current_page=page,
total_items=0,
total_pages=0,
page_size=page_size,
),
agents = await prisma.models.StoreAgent.prisma().find_many(
where=where_clause,
order=order_by,
skip=(page - 1) * page_size,
take=page_size,
)

total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
total_pages = (total + page_size - 1) // page_size

store_agents: list[backend.server.v2.store.model.StoreAgent] = []
for agent in agents:
try:
# Create the StoreAgent object safely
store_agent = backend.server.v2.store.model.StoreAgent(
slug=agent.slug,
agent_name=agent.agent_name,
agent_image=agent.agent_image[0] if agent.agent_image else "",
creator=agent.creator_username or "Needs Profile",
creator_avatar=agent.creator_avatar or "",
sub_heading=agent.sub_heading,
description=agent.description,
runs=agent.runs,
rating=agent.rating,
)

offset = (page - 1) * page_size

# Whitelist allowed order_by columns
ALLOWED_ORDER_BY = {
"rating": "rating DESC, rank DESC",
"runs": "runs DESC, rank DESC",
"name": "agent_name ASC, rank DESC",
"updated_at": "updated_at DESC, rank DESC",
}

# Validate and get order clause
if sorted_by and sorted_by in ALLOWED_ORDER_BY:
order_by_clause = ALLOWED_ORDER_BY[sorted_by]
else:
order_by_clause = "updated_at DESC, rank DESC"

# Build WHERE conditions and parameters list
where_parts: list[str] = []
params: list[typing.Any] = [search_term] # $1 - search term
param_index = 2 # Start at $2 for next parameter

# Always filter for available agents
where_parts.append("is_available = true")

if featured:
where_parts.append("featured = true")

if creators and sanitized_creators:
# Use ANY with array parameter
where_parts.append(f"creator_username = ANY(${param_index})")
params.append(sanitized_creators)
param_index += 1

if category and sanitized_category:
where_parts.append(f"${param_index} = ANY(categories)")
params.append(sanitized_category)
param_index += 1

sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1"

# Add pagination params
params.extend([page_size, offset])
limit_param = f"${param_index}"
offset_param = f"${param_index + 1}"

# Execute full-text search query with parameterized values
sql_query = f"""
SELECT
slug,
agent_name,
agent_image,
creator_username,
creator_avatar,
sub_heading,
description,
runs,
rating,
categories,
featured,
is_available,
updated_at,
ts_rank_cd(search, query) AS rank
FROM "StoreAgent",
plainto_tsquery('english', $1) AS query
WHERE {sql_where_clause}
AND search @@ query
ORDER BY {order_by_clause}
LIMIT {limit_param} OFFSET {offset_param}
"""

# Count query for pagination - only uses search term parameter
count_query = f"""
SELECT COUNT(*) as count
FROM "StoreAgent",
plainto_tsquery('english', $1) AS query
WHERE {sql_where_clause}
AND search @@ query
"""

# Execute both queries with parameters
agents = await prisma.client.get_client().query_raw(
typing.cast(typing.LiteralString, sql_query), *params
)

# For count, use params without pagination (last 2 params)
count_params = params[:-2]
count_result = await prisma.client.get_client().query_raw(
typing.cast(typing.LiteralString, count_query), *count_params
)

total = count_result[0]["count"] if count_result else 0
total_pages = (total + page_size - 1) // page_size

# Convert raw results to StoreAgent models
store_agents: list[backend.server.v2.store.model.StoreAgent] = []
for agent in agents:
try:
store_agent = backend.server.v2.store.model.StoreAgent(
slug=agent["slug"],
agent_name=agent["agent_name"],
agent_image=(
agent["agent_image"][0] if agent["agent_image"] else ""
),
creator=agent["creator_username"] or "Needs Profile",
creator_avatar=agent["creator_avatar"] or "",
sub_heading=agent["sub_heading"],
description=agent["description"],
runs=agent["runs"],
rating=agent["rating"],
)
store_agents.append(store_agent)
except Exception as e:
logger.error(f"Error parsing Store agent from search results: {e}")
continue

else:
# Non-search query path (original logic)
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
if featured:
where_clause["featured"] = featured
if creators:
where_clause["creator_username"] = {"in": sanitized_creators}
if sanitized_category:
where_clause["categories"] = {"has": sanitized_category}

order_by = []
if sorted_by == "rating":
order_by.append({"rating": "desc"})
elif sorted_by == "runs":
order_by.append({"runs": "desc"})
elif sorted_by == "name":
order_by.append({"agent_name": "asc"})

agents = await prisma.models.StoreAgent.prisma().find_many(
where=where_clause,
order=order_by,
skip=(page - 1) * page_size,
take=page_size,
)

total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
total_pages = (total + page_size - 1) // page_size

store_agents: list[backend.server.v2.store.model.StoreAgent] = []
for agent in agents:
try:
# Create the StoreAgent object safely
store_agent = backend.server.v2.store.model.StoreAgent(
slug=agent.slug,
agent_name=agent.agent_name,
agent_image=agent.agent_image[0] if agent.agent_image else "",
creator=agent.creator_username or "Needs Profile",
creator_avatar=agent.creator_avatar or "",
sub_heading=agent.sub_heading,
description=agent.description,
runs=agent.runs,
rating=agent.rating,
)
# Add to the list only if creation was successful
store_agents.append(store_agent)
except Exception as e:
# Skip this agent if there was an error
# You could log the error here if needed
logger.error(
f"Error parsing Store agent when getting store agents from db: {e}"
)
continue
# Add to the list only if creation was successful
store_agents.append(store_agent)
except Exception as e:
# Skip this agent if there was an error
# You could log the error here if needed
logger.error(
f"Error parsing Store agent when getting store agents from db: {e}"
)
continue

logger.debug(f"Found {len(store_agents)} agents")
return backend.server.v2.store.model.StoreAgentsResponse(
@@ -278,25 +142,9 @@ async def get_store_agents(
)
except Exception as e:
logger.error(f"Error getting store agents: {e}")
raise DatabaseError("Failed to fetch store agents") from e
# TODO: commenting this out as we are concerned about potential db load issues
# finally:
#     if search_term:
#         await log_search_term(search_query=search_term)


async def log_search_term(search_query: str):
"""Log a search term to the database"""

# Anonymize the data by preventing correlation with other logs
date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
try:
await prisma.models.SearchTerms.prisma().create(
data={"searchTerm": search_query, "createdDate": date}
)
except Exception as e:
# Fail silently here so that logging search terms doesn't break the app
logger.error(f"Error logging search term: {e}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch store agents"
) from e


async def get_store_agent_details(
@@ -389,7 +237,9 @@ async def get_store_agent_details(
raise
except Exception as e:
logger.error(f"Error getting store agent details: {e}")
raise DatabaseError("Failed to fetch agent details") from e
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch agent details"
) from e


async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
@@ -416,7 +266,9 @@ async def get_available_graph(store_listing_version_id: str) -> GraphMeta:

except Exception as e:
logger.error(f"Error getting agent: {e}")
raise DatabaseError("Failed to fetch agent") from e
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch agent"
) from e


async def get_store_agent_by_version_id(
@@ -456,7 +308,9 @@ async def get_store_agent_by_version_id(
raise
except Exception as e:
logger.error(f"Error getting store agent details: {e}")
raise DatabaseError("Failed to fetch agent details") from e
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch agent details"
) from e


async def get_store_creators(
@@ -482,7 +336,9 @@ async def get_store_creators(
# Sanitize and validate search query by escaping special characters
sanitized_query = search_query.strip()
if not sanitized_query or len(sanitized_query) > 100: # Reasonable length limit
raise DatabaseError("Invalid search query")
raise backend.server.v2.store.exceptions.DatabaseError(
"Invalid search query"
)

# Escape special SQL characters
sanitized_query = (
@@ -508,9 +364,11 @@ async def get_store_creators(
try:
# Validate pagination parameters
if not isinstance(page, int) or page < 1:
raise DatabaseError("Invalid page number")
raise backend.server.v2.store.exceptions.DatabaseError(
"Invalid page number"
)
if not isinstance(page_size, int) or page_size < 1 or page_size > 100:
raise DatabaseError("Invalid page size")
raise backend.server.v2.store.exceptions.DatabaseError("Invalid page size")

# Get total count for pagination using sanitized where clause
total = await prisma.models.Creator.prisma().count(
@@ -565,7 +423,9 @@ async def get_store_creators(
)
except Exception as e:
logger.error(f"Error getting store creators: {e}")
raise DatabaseError("Failed to fetch store creators") from e
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch store creators"
) from e


async def get_store_creator_details(
@@ -600,7 +460,9 @@ async def get_store_creator_details(
raise
except Exception as e:
logger.error(f"Error getting store creator details: {e}")
raise DatabaseError("Failed to fetch creator details") from e
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch creator details"
) from e


async def get_store_submissions(
@@ -631,6 +493,7 @@ async def get_store_submissions(
submission_models = []
for sub in submissions:
submission_model = backend.server.v2.store.model.StoreSubmission(
user_id=sub.user_id,
agent_id=sub.agent_id,
agent_version=sub.agent_version,
name=sub.name,
@@ -848,6 +711,7 @@ async def create_store_submission(
logger.debug(f"Created store listing for agent {agent_id}")
# Return submission details
return backend.server.v2.store.model.StoreSubmission(
user_id=user_id,
agent_id=agent_id,
agent_version=agent_version,
name=name,
@@ -863,21 +727,7 @@ async def create_store_submission(
store_listing_version_id=store_listing_version_id,
changes_summary=changes_summary,
)
except prisma.errors.UniqueViolationError as exc:
# Attempt to check if the error was due to the slug field being unique
error_str = str(exc)
if "slug" in error_str.lower():
logger.debug(
f"Slug '{slug}' is already in use by another agent (agent_id: {agent_id}) for user {user_id}"
)
raise backend.server.v2.store.exceptions.SlugAlreadyInUseError(
f"The URL slug '{slug}' is already in use by another one of your agents. Please choose a different slug."
) from exc
else:
# Reraise as a generic database error for other unique violations
raise DatabaseError(
f"Unique constraint violated (not slug): {error_str}"
) from exc

except (
backend.server.v2.store.exceptions.AgentNotFoundError,
backend.server.v2.store.exceptions.ListingExistsError,
@@ -885,7 +735,9 @@ async def create_store_submission(
raise
except prisma.errors.PrismaError as e:
logger.error(f"Database error creating store submission: {e}")
raise DatabaseError("Failed to create store submission") from e
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to create store submission"
) from e


async def edit_store_submission(
@@ -1006,8 +858,11 @@ async def edit_store_submission(
)

if not updated_version:
raise DatabaseError("Failed to update store listing version")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to update store listing version"
)
return backend.server.v2.store.model.StoreSubmission(
user_id=user_id,
agent_id=current_version.agentGraphId,
agent_version=current_version.agentGraphVersion,
name=name,
@@ -1042,7 +897,9 @@ async def edit_store_submission(
raise
except prisma.errors.PrismaError as e:
logger.error(f"Database error editing store submission: {e}")
raise DatabaseError("Failed to edit store submission") from e
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to edit store submission"
) from e


async def create_store_version(
@@ -1139,6 +996,7 @@ async def create_store_version(
)
# Return submission details
return backend.server.v2.store.model.StoreSubmission(
user_id=user_id,
agent_id=agent_id,
agent_version=agent_version,
name=name,
@@ -1157,7 +1015,9 @@ async def create_store_version(
)

except prisma.errors.PrismaError as e:
raise DatabaseError("Failed to create new store version") from e
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to create new store version"
) from e


async def create_store_review(
@@ -1197,7 +1057,9 @@ async def create_store_review(

except prisma.errors.PrismaError as e:
logger.error(f"Database error creating store review: {e}")
raise DatabaseError("Failed to create store review") from e
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to create store review"
) from e


async def get_user_profile(
@@ -1221,7 +1083,9 @@ async def get_user_profile(
)
except Exception as e:
logger.error(f"Error getting user profile: {e}")
raise DatabaseError("Failed to get user profile") from e
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to get user profile"
) from e


async def update_profile(
@@ -1258,7 +1122,7 @@ async def update_profile(
logger.error(
f"Unauthorized update attempt for profile {existing_profile.id} by user {user_id}"
)
raise DatabaseError(
raise backend.server.v2.store.exceptions.DatabaseError(
f"Unauthorized update attempt for profile {existing_profile.id} by user {user_id}"
)

@@ -1283,7 +1147,9 @@ async def update_profile(
)
if updated_profile is None:
logger.error(f"Failed to update profile for user {user_id}")
raise DatabaseError("Failed to update profile")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to update profile"
)

return backend.server.v2.store.model.CreatorDetails(
name=updated_profile.name,
@@ -1298,7 +1164,9 @@ async def update_profile(

except prisma.errors.PrismaError as e:
logger.error(f"Database error updating profile: {e}")
raise DatabaseError("Failed to update profile") from e
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to update profile"
) from e


async def get_my_agents(
@@ -1366,7 +1234,9 @@ async def get_my_agents(
)
except Exception as e:
logger.error(f"Error getting my agents: {e}")
raise DatabaseError("Failed to fetch my agents") from e
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch my agents"
) from e


async def get_agent(store_listing_version_id: str) -> GraphModel:
@@ -1627,8 +1497,8 @@ async def review_store_submission(
include={"StoreListing": True},
)

if not submission:
raise DatabaseError(
if not submission or not submission.StoreListing:
raise backend.server.v2.store.exceptions.DatabaseError(
f"Failed to update store listing version {store_listing_version_id}"
)

@@ -1717,6 +1587,7 @@ async def review_store_submission(

# Convert to Pydantic model for consistency
return backend.server.v2.store.model.StoreSubmission(
user_id=submission.StoreListing.owningUserId,
agent_id=submission.agentGraphId,
agent_version=submission.agentGraphVersion,
name=submission.name,
@@ -1743,7 +1614,9 @@ async def review_store_submission(

except Exception as e:
logger.error(f"Could not create store submission review: {e}")
raise DatabaseError("Failed to create store submission review") from e
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to create store submission review"
) from e


async def get_admin_listings_with_versions(
@@ -1847,14 +1720,17 @@ async def get_admin_listings_with_versions(
# Get total count for pagination
total = await prisma.models.StoreListing.prisma().count(where=where)
total_pages = (total + page_size - 1) // page_size

# Convert to response models
listings_with_versions = []
for listing in listings:
versions: list[backend.server.v2.store.model.StoreSubmission] = []
if not listing.OwningUser:
logger.error(f"Listing {listing.id} has no owning user")
continue
# If we have versions, turn them into StoreSubmission models
for version in listing.Versions or []:
version_model = backend.server.v2.store.model.StoreSubmission(
user_id=listing.OwningUser.id,
agent_id=version.agentGraphId,
agent_version=version.agentGraphVersion,
name=version.name,
@@ -1922,27 +1798,6 @@ async def get_admin_listings_with_versions(
)


async def check_submission_already_approved(
store_listing_version_id: str,
) -> bool:
"""Check the submission status of a store listing version."""
try:
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}
)
)
if not store_listing_version:
return False
return (
store_listing_version.submissionStatus
== prisma.enums.SubmissionStatus.APPROVED
)
except Exception as e:
logger.error(f"Error checking submission status: {e}")
return False


async def get_agent_as_admin(
user_id: str | None,
store_listing_version_id: str,

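The expression `(total + page_size - 1) // page_size` recurs throughout these hunks to compute page counts. It is the standard integer ceiling-division identity:

```python
def total_pages(total_items: int, page_size: int) -> int:
    """Ceiling division without floats: equivalent to math.ceil(total/size)."""
    return (total_items + page_size - 1) // page_size
```

Adding `page_size - 1` before the floor division rounds any partial final page up to a whole page, while an exact multiple is unaffected.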
@@ -20,7 +20,7 @@ async def setup_prisma():
    yield


@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio
async def test_get_store_agents(mocker):
    # Mock data
    mock_agents = [
@@ -42,7 +42,6 @@ async def test_get_store_agents(mocker):
            versions=["1.0"],
            updated_at=datetime.now(),
            is_available=False,
            useForOnboarding=False,
        )
    ]

@@ -64,7 +63,7 @@ async def test_get_store_agents(mocker):
    mock_store_agent.return_value.count.assert_called_once()


@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio
async def test_get_store_agent_details(mocker):
    # Mock data
    mock_agent = prisma.models.StoreAgent(
@@ -85,7 +84,6 @@ async def test_get_store_agent_details(mocker):
        versions=["1.0"],
        updated_at=datetime.now(),
        is_available=False,
        useForOnboarding=False,
    )

    # Mock active version agent (what we want to return for active version)
@@ -107,7 +105,6 @@ async def test_get_store_agent_details(mocker):
        versions=["1.0", "2.0"],
        updated_at=datetime.now(),
        is_available=True,
        useForOnboarding=False,
    )

    # Create a mock StoreListing result
@@ -173,7 +170,7 @@ async def test_get_store_agent_details(mocker):
    mock_store_listing_db.return_value.find_first.assert_called_once()


@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio
async def test_get_store_creator_details(mocker):
    # Mock data
    mock_creator_data = prisma.models.Creator(
@@ -210,7 +207,7 @@ async def test_get_store_creator_details(mocker):
    )


@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio
async def test_create_store_submission(mocker):
    # Mock data
    mock_agent = prisma.models.AgentGraph(
@@ -251,7 +248,6 @@ async def test_create_store_submission(mocker):
                isAvailable=True,
            )
        ],
        useForOnboarding=False,
    )

    # Mock prisma calls
@@ -279,10 +275,11 @@ async def test_create_store_submission(mocker):

    # Verify mocks called correctly
    mock_agent_graph.return_value.find_first.assert_called_once()
    mock_store_listing.return_value.find_first.assert_called_once()
    mock_store_listing.return_value.create.assert_called_once()


@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio
async def test_update_profile(mocker):
    # Mock data
    mock_profile = prisma.models.Profile(
@@ -327,7 +324,7 @@ async def test_update_profile(mocker):
    mock_profile_db.return_value.update.assert_called_once()


@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio
async def test_get_user_profile(mocker):
    # Mock data
    mock_profile = prisma.models.Profile(
@@ -359,63 +356,3 @@ async def test_get_user_profile(mocker):
    assert result.description == "Test description"
    assert result.links == ["link1", "link2"]
    assert result.avatar_url == "avatar.jpg"


@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_agents_with_search_parameterized(mocker):
    """Test that search query uses parameterized SQL - validates the fix works"""

    # Call function with search query containing potential SQL injection
    malicious_search = "test'; DROP TABLE StoreAgent; --"
    result = await db.get_store_agents(search_query=malicious_search)

    # Verify query executed safely
    assert isinstance(result.agents, list)


@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_agents_with_search_and_filters_parameterized():
    """Test parameterized SQL with multiple filters"""

    # Call with multiple filters including potential injection attempts
    result = await db.get_store_agents(
        search_query="test",
        creators=["creator1'; DROP TABLE Users; --", "creator2"],
        category="AI'; DELETE FROM StoreAgent; --",
        featured=True,
        sorted_by="rating",
        page=1,
        page_size=20,
    )

    # Verify the query executed without error
    assert isinstance(result.agents, list)


@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_agents_search_with_invalid_sort_by():
    """Test that invalid sorted_by value doesn't cause SQL injection"""

    # Try to inject SQL via sorted_by parameter
    malicious_sort = "rating; DROP TABLE Users; --"
    result = await db.get_store_agents(
        search_query="test",
        sorted_by=malicious_sort,
    )

    # Verify the query executed without error
    # Invalid sort_by should fall back to default, not cause SQL injection
    assert isinstance(result.agents, list)


@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_agents_search_category_array_injection():
    """Test that category parameter is safely passed as a parameter"""
    # Try SQL injection via category
    malicious_category = "AI'; DROP TABLE StoreAgent; --"
    result = await db.get_store_agents(
        search_query="test",
        category=malicious_category,
    )

    # Verify the query executed without error
    # Category should be parameterized, preventing SQL injection
    assert isinstance(result.agents, list)
@@ -1,7 +1,4 @@
from backend.util.exceptions import NotFoundError


class MediaUploadError(ValueError):
class MediaUploadError(Exception):
    """Base exception for media upload errors"""

    pass
@@ -51,19 +48,19 @@ class VirusScanError(MediaUploadError):
    pass


class StoreError(ValueError):
class StoreError(Exception):
    """Base exception for store-related errors"""

    pass


class AgentNotFoundError(NotFoundError):
class AgentNotFoundError(StoreError):
    """Raised when an agent is not found"""

    pass


class CreatorNotFoundError(NotFoundError):
class CreatorNotFoundError(StoreError):
    """Raised when a creator is not found"""

    pass
@@ -75,19 +72,25 @@ class ListingExistsError(StoreError):
    pass


class ProfileNotFoundError(NotFoundError):
class DatabaseError(StoreError):
    """Raised when there is an error interacting with the database"""

    pass


class ProfileNotFoundError(StoreError):
    """Raised when a profile is not found"""

    pass


class ListingNotFoundError(NotFoundError):
class ListingNotFoundError(StoreError):
    """Raised when a store listing is not found"""

    pass


class SubmissionNotFoundError(NotFoundError):
class SubmissionNotFoundError(StoreError):
    """Raised when a submission is not found"""

    pass
@@ -103,9 +106,3 @@ class UnauthorizedError(StoreError):
    """Raised when a user is not authorized to perform an action"""

    pass


class SlugAlreadyInUseError(StoreError):
    """Raised when a slug is already in use by another agent owned by the user"""

    pass
@@ -98,6 +98,7 @@ class Profile(pydantic.BaseModel):


class StoreSubmission(pydantic.BaseModel):
    user_id: str = pydantic.Field(default="", exclude=True)
    agent_id: str
    agent_version: int
    name: str

@@ -135,6 +135,7 @@ def test_creator_details():

def test_store_submission():
    submission = backend.server.v2.store.model.StoreSubmission(
        user_id="user123",
        agent_id="agent123",
        agent_version=1,
        sub_heading="Test subheading",
@@ -156,6 +157,7 @@ def test_store_submissions_response():
    response = backend.server.v2.store.model.StoreSubmissionsResponse(
        submissions=[
            backend.server.v2.store.model.StoreSubmission(
                user_id="user123",
                agent_id="agent123",
                agent_version=1,
                sub_heading="Test subheading",
@@ -8,13 +8,25 @@ import fastapi
import fastapi.responses

import backend.data.graph
import backend.server.v2.store.cache as store_cache
import backend.server.cache_config
import backend.server.v2.store.db
import backend.server.v2.store.exceptions
import backend.server.v2.store.image_gen
import backend.server.v2.store.media
import backend.server.v2.store.model
import backend.util.json
from backend.server.v2.store.cache import (
    _clear_submissions_cache,
    _get_cached_agent_details,
    _get_cached_agent_graph,
    _get_cached_creator_details,
    _get_cached_my_agents,
    _get_cached_store_agent_by_version,
    _get_cached_store_agents,
    _get_cached_store_creators,
    _get_cached_submissions,
    _get_cached_user_profile,
)

logger = logging.getLogger(__name__)
@@ -40,13 +52,23 @@ async def get_profile(
    Get the profile details for the authenticated user.
    Cached for 1 hour per user.
    """
    profile = await backend.server.v2.store.db.get_user_profile(user_id)
    if profile is None:
    try:
        profile = await _get_cached_user_profile(user_id)
        if profile is None:
            return fastapi.responses.JSONResponse(
                status_code=404,
                content={"detail": "Profile not found"},
            )
        return profile
    except Exception as e:
        logger.exception("Failed to fetch user profile for %s: %s", user_id, e)
        return fastapi.responses.JSONResponse(
            status_code=404,
            content={"detail": "Profile not found"},
            status_code=500,
            content={
                "detail": "Failed to retrieve user profile",
                "hint": "Check database connection.",
            },
        )
    return profile


@router.post(
@@ -73,10 +95,22 @@ async def update_or_create_profile(
    Raises:
        HTTPException: If there is an error updating the profile
    """
    updated_profile = await backend.server.v2.store.db.update_profile(
        user_id=user_id, profile=profile
    )
    return updated_profile
    try:
        updated_profile = await backend.server.v2.store.db.update_profile(
            user_id=user_id, profile=profile
        )
        # Clear the cache for this user after profile update
        _get_cached_user_profile.cache_delete(user_id)
        return updated_profile
    except Exception as e:
        logger.exception("Failed to update profile for user %s: %s", user_id, e)
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={
                "detail": "Failed to update user profile",
                "hint": "Validate request data.",
            },
        )


##############################################
@@ -97,10 +131,11 @@ async def get_agents(
    search_query: str | None = None,
    category: str | None = None,
    page: int = 1,
    page_size: int = 20,
    page_size: int = backend.server.cache_config.V2_STORE_AGENTS_PAGE_SIZE,
):
    """
    Get a paginated list of agents from the store with optional filtering and sorting.
    Results are cached for 15 minutes.

    Args:
        featured (bool, optional): Filter to only show featured agents. Defaults to False.
@@ -135,16 +170,26 @@ async def get_agents(
            status_code=422, detail="Page size must be greater than 0"
        )

    agents = await store_cache._get_cached_store_agents(
        featured=featured,
        creator=creator,
        sorted_by=sorted_by,
        search_query=search_query,
        category=category,
        page=page,
        page_size=page_size,
    )
    return agents
    try:
        agents = await _get_cached_store_agents(
            featured=featured,
            creator=creator,
            sorted_by=sorted_by,
            search_query=search_query,
            category=category,
            page=page,
            page_size=page_size,
        )
        return agents
    except Exception as e:
        logger.exception("Failed to retrieve store agents: %s", e)
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={
                "detail": "Failed to retrieve store agents",
                "hint": "Check database or search parameters.",
            },
        )


@router.get(
@@ -156,16 +201,26 @@ async def get_agents(
async def get_agent(username: str, agent_name: str):
    """
    This is only used on the AgentDetails Page.
    Results are cached for 15 minutes.

    It returns the store listing agents details.
    """
    username = urllib.parse.unquote(username).lower()
    # URL decode the agent name since it comes from the URL path
    agent_name = urllib.parse.unquote(agent_name).lower()
    agent = await store_cache._get_cached_agent_details(
        username=username, agent_name=agent_name
    )
    return agent
    try:
        username = urllib.parse.unquote(username).lower()
        # URL decode the agent name since it comes from the URL path
        agent_name = urllib.parse.unquote(agent_name).lower()
        agent = await _get_cached_agent_details(
            username=username, agent_name=agent_name
        )
        return agent
    except Exception:
        logger.exception("Exception occurred whilst getting store agent details")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={
                "detail": "An error occurred while retrieving the store agent details"
            },
        )


@router.get(
@@ -177,11 +232,17 @@ async def get_agent(username: str, agent_name: str):
async def get_graph_meta_by_store_listing_version_id(store_listing_version_id: str):
    """
    Get Agent Graph from Store Listing Version ID.
    Results are cached for 1 hour.
    """
    graph = await backend.server.v2.store.db.get_available_graph(
        store_listing_version_id
    )
    return graph
    try:
        graph = await _get_cached_agent_graph(store_listing_version_id)
        return graph
    except Exception:
        logger.exception("Exception occurred whilst getting agent graph")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while retrieving the agent graph"},
        )


@router.get(
@@ -194,12 +255,17 @@ async def get_graph_meta_by_store_listing_version_id(store_listing_version_id: s
async def get_store_agent(store_listing_version_id: str):
    """
    Get Store Agent Details from Store Listing Version ID.
    Results are cached for 1 hour.
    """
    agent = await backend.server.v2.store.db.get_store_agent_by_version_id(
        store_listing_version_id
    )

    return agent
    try:
        agent = await _get_cached_store_agent_by_version(store_listing_version_id)
        return agent
    except Exception:
        logger.exception("Exception occurred whilst getting store agent")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while retrieving the store agent"},
        )


@router.post(
@@ -227,17 +293,24 @@ async def create_review(
    Returns:
        The created review
    """
    username = urllib.parse.unquote(username).lower()
    agent_name = urllib.parse.unquote(agent_name).lower()
    # Create the review
    created_review = await backend.server.v2.store.db.create_store_review(
        user_id=user_id,
        store_listing_version_id=review.store_listing_version_id,
        score=review.score,
        comments=review.comments,
    )
    try:
        username = urllib.parse.unquote(username).lower()
        agent_name = urllib.parse.unquote(agent_name).lower()
        # Create the review
        created_review = await backend.server.v2.store.db.create_store_review(
            user_id=user_id,
            store_listing_version_id=review.store_listing_version_id,
            score=review.score,
            comments=review.comments,
        )

    return created_review
        return created_review
    except Exception:
        logger.exception("Exception occurred whilst creating store review")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while creating the store review"},
        )


##############################################
@@ -256,13 +329,15 @@ async def get_creators(
    search_query: str | None = None,
    sorted_by: str | None = None,
    page: int = 1,
    page_size: int = 20,
    page_size: int = backend.server.cache_config.V2_STORE_CREATORS_PAGE_SIZE,
):
    """
    This is needed for:
    - Home Page Featured Creators
    - Search Results Page

    Results are cached for 1 hour.

    ---

    To support this functionality we need:
@@ -280,14 +355,21 @@ async def get_creators(
            status_code=422, detail="Page size must be greater than 0"
        )

    creators = await store_cache._get_cached_store_creators(
        featured=featured,
        search_query=search_query,
        sorted_by=sorted_by,
        page=page,
        page_size=page_size,
    )
    return creators
    try:
        creators = await _get_cached_store_creators(
            featured=featured,
            search_query=search_query,
            sorted_by=sorted_by,
            page=page,
            page_size=page_size,
        )
        return creators
    except Exception:
        logger.exception("Exception occurred whilst getting store creators")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while retrieving the store creators"},
        )


@router.get(
@@ -301,11 +383,21 @@ async def get_creator(
):
    """
    Get the details of a creator.
    Results are cached for 1 hour.
    - Creator Details Page
    """
    username = urllib.parse.unquote(username).lower()
    creator = await store_cache._get_cached_creator_details(username=username)
    return creator
    try:
        username = urllib.parse.unquote(username).lower()
        creator = await _get_cached_creator_details(username=username)
        return creator
    except Exception:
        logger.exception("Exception occurred whilst getting creator details")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={
                "detail": "An error occurred while retrieving the creator details"
            },
        )


############################################
@@ -323,15 +415,23 @@ async def get_creator(
async def get_my_agents(
    user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
    page: typing.Annotated[int, fastapi.Query(ge=1)] = 1,
    page_size: typing.Annotated[int, fastapi.Query(ge=1)] = 20,
    page_size: typing.Annotated[
        int, fastapi.Query(ge=1)
    ] = backend.server.cache_config.V2_MY_AGENTS_PAGE_SIZE,
):
    """
    Get user's own agents.
    Results are cached for 5 minutes per user.
    """
    agents = await backend.server.v2.store.db.get_my_agents(
        user_id, page=page, page_size=page_size
    )
    return agents
    try:
        agents = await _get_cached_my_agents(user_id, page=page, page_size=page_size)
        return agents
    except Exception:
        logger.exception("Exception occurred whilst getting my agents")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while retrieving the my agents"},
        )


@router.delete(
@@ -355,12 +455,23 @@ async def delete_submission(
    Returns:
        bool: True if the submission was successfully deleted, False otherwise
    """
    result = await backend.server.v2.store.db.delete_store_submission(
        user_id=user_id,
        submission_id=submission_id,
    )
    try:
        result = await backend.server.v2.store.db.delete_store_submission(
            user_id=user_id,
            submission_id=submission_id,
        )

    return result
        # Clear submissions cache for this specific user after deletion
        if result:
            _clear_submissions_cache(user_id)

        return result
    except Exception:
        logger.exception("Exception occurred whilst deleting store submission")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while deleting the store submission"},
        )


@router.get(
@@ -373,10 +484,11 @@ async def delete_submission(
async def get_submissions(
    user_id: str = fastapi.Security(autogpt_libs.auth.get_user_id),
    page: int = 1,
    page_size: int = 20,
    page_size: int = backend.server.cache_config.V2_STORE_SUBMISSIONS_PAGE_SIZE,
):
    """
    Get a paginated list of store submissions for the authenticated user.
    Results are cached for 1 hour per user.

    Args:
        user_id (str): ID of the authenticated user
@@ -398,12 +510,19 @@ async def get_submissions(
        raise fastapi.HTTPException(
            status_code=422, detail="Page size must be greater than 0"
        )
    listings = await backend.server.v2.store.db.get_store_submissions(
        user_id=user_id,
        page=page,
        page_size=page_size,
    )
    return listings
    try:
        listings = await _get_cached_submissions(
            user_id, page=page, page_size=page_size
        )
        return listings
    except Exception:
        logger.exception("Exception occurred whilst getting store submissions")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={
                "detail": "An error occurred while retrieving the store submissions"
            },
        )


@router.post(
@@ -430,23 +549,32 @@ async def create_submission(
    Raises:
        HTTPException: If there is an error creating the submission
    """
    result = await backend.server.v2.store.db.create_store_submission(
        user_id=user_id,
        agent_id=submission_request.agent_id,
        agent_version=submission_request.agent_version,
        slug=submission_request.slug,
        name=submission_request.name,
        video_url=submission_request.video_url,
        image_urls=submission_request.image_urls,
        description=submission_request.description,
        instructions=submission_request.instructions,
        sub_heading=submission_request.sub_heading,
        categories=submission_request.categories,
        changes_summary=submission_request.changes_summary or "Initial Submission",
        recommended_schedule_cron=submission_request.recommended_schedule_cron,
    )
    try:
        result = await backend.server.v2.store.db.create_store_submission(
            user_id=user_id,
            agent_id=submission_request.agent_id,
            agent_version=submission_request.agent_version,
            slug=submission_request.slug,
            name=submission_request.name,
            video_url=submission_request.video_url,
            image_urls=submission_request.image_urls,
            description=submission_request.description,
            instructions=submission_request.instructions,
            sub_heading=submission_request.sub_heading,
            categories=submission_request.categories,
            changes_summary=submission_request.changes_summary or "Initial Submission",
            recommended_schedule_cron=submission_request.recommended_schedule_cron,
        )

    return result
        _clear_submissions_cache(user_id)

        return result
    except Exception:
        logger.exception("Exception occurred whilst creating store submission")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while creating the store submission"},
        )


@router.put(
@@ -489,6 +617,8 @@ async def edit_submission(
        recommended_schedule_cron=submission_request.recommended_schedule_cron,
    )

    _clear_submissions_cache(user_id)

    return result
@@ -515,10 +645,36 @@ async def upload_submission_media(
    Raises:
        HTTPException: If there is an error uploading the media
    """
    media_url = await backend.server.v2.store.media.upload_media(
        user_id=user_id, file=file
    )
    return media_url
    try:
        media_url = await backend.server.v2.store.media.upload_media(
            user_id=user_id, file=file
        )
        return media_url
    except backend.server.v2.store.exceptions.VirusDetectedError as e:
        logger.warning(f"Virus detected in uploaded file: {e.threat_name}")
        return fastapi.responses.JSONResponse(
            status_code=400,
            content={
                "detail": f"File rejected due to virus detection: {e.threat_name}",
                "error_type": "virus_detected",
                "threat_name": e.threat_name,
            },
        )
    except backend.server.v2.store.exceptions.VirusScanError as e:
        logger.error(f"Virus scanning failed: {str(e)}")
        return fastapi.responses.JSONResponse(
            status_code=503,
            content={
                "detail": "Virus scanning service unavailable. Please try again later.",
                "error_type": "virus_scan_failed",
            },
        )
    except Exception:
        logger.exception("Exception occurred whilst uploading submission media")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while uploading the media file"},
        )


@router.post(
@@ -541,35 +697,44 @@ async def generate_image(
    Returns:
        JSONResponse: JSON containing the URL of the generated image
    """
    agent = await backend.data.graph.get_graph(agent_id, user_id=user_id)
    try:
        agent = await backend.data.graph.get_graph(agent_id, user_id=user_id)

    if not agent:
        raise fastapi.HTTPException(
            status_code=404, detail=f"Agent with ID {agent_id} not found"
        if not agent:
            raise fastapi.HTTPException(
                status_code=404, detail=f"Agent with ID {agent_id} not found"
            )
        # Use .jpeg here since we are generating JPEG images
        filename = f"agent_{agent_id}.jpeg"

        existing_url = await backend.server.v2.store.media.check_media_exists(
            user_id, filename
        )
        if existing_url:
            logger.info(f"Using existing image for agent {agent_id}")
            return fastapi.responses.JSONResponse(content={"image_url": existing_url})
        # Generate agent image as JPEG
        image = await backend.server.v2.store.image_gen.generate_agent_image(
            agent=agent
        )
    # Use .jpeg here since we are generating JPEG images
    filename = f"agent_{agent_id}.jpeg"

    existing_url = await backend.server.v2.store.media.check_media_exists(
        user_id, filename
    )
    if existing_url:
        logger.info(f"Using existing image for agent {agent_id}")
        return fastapi.responses.JSONResponse(content={"image_url": existing_url})
    # Generate agent image as JPEG
    image = await backend.server.v2.store.image_gen.generate_agent_image(agent=agent)
        # Create UploadFile with the correct filename and content_type
        image_file = fastapi.UploadFile(
            file=image,
            filename=filename,
        )

    # Create UploadFile with the correct filename and content_type
    image_file = fastapi.UploadFile(
        file=image,
        filename=filename,
    )
        image_url = await backend.server.v2.store.media.upload_media(
            user_id=user_id, file=image_file, use_file_name=True
        )

    image_url = await backend.server.v2.store.media.upload_media(
        user_id=user_id, file=image_file, use_file_name=True
    )

        return fastapi.responses.JSONResponse(content={"image_url": image_url})
    return fastapi.responses.JSONResponse(content={"image_url": image_url})
    except Exception:
        logger.exception("Exception occurred whilst generating submission image")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while generating the image"},
        )


@router.get(
@@ -646,10 +811,15 @@ async def get_cache_metrics():
    )

    # Add metrics for each cache
    add_cache_metrics("store_agents", store_cache._get_cached_store_agents)
    add_cache_metrics("agent_details", store_cache._get_cached_agent_details)
    add_cache_metrics("store_creators", store_cache._get_cached_store_creators)
    add_cache_metrics("creator_details", store_cache._get_cached_creator_details)
    add_cache_metrics("user_profile", _get_cached_user_profile)
    add_cache_metrics("store_agents", _get_cached_store_agents)
    add_cache_metrics("agent_details", _get_cached_agent_details)
    add_cache_metrics("agent_graph", _get_cached_agent_graph)
    add_cache_metrics("agent_by_version", _get_cached_store_agent_by_version)
    add_cache_metrics("store_creators", _get_cached_store_creators)
    add_cache_metrics("creator_details", _get_cached_creator_details)
    add_cache_metrics("my_agents", _get_cached_my_agents)
    add_cache_metrics("submissions", _get_cached_submissions)

    # Add metadata/help text at the beginning
    prometheus_output = [
@@ -534,6 +534,7 @@ def test_get_submissions_success(
    mocked_value = backend.server.v2.store.model.StoreSubmissionsResponse(
        submissions=[
            backend.server.v2.store.model.StoreSubmission(
                user_id="user123",
                name="Test Agent",
                description="Test agent description",
                image_urls=["test.jpg"],
@@ -4,12 +4,18 @@ Test suite for verifying cache_delete functionality in store routes.
|
||||
Tests that specific cache entries can be deleted while preserving others.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.store import cache as store_cache
|
||||
from backend.server.v2.store.model import StoreAgent, StoreAgentsResponse
|
||||
from backend.server.v2.store import routes
|
||||
from backend.server.v2.store.model import (
    ProfileDetails,
    StoreAgent,
    StoreAgentDetails,
    StoreAgentsResponse,
)
from backend.util.models import Pagination


@@ -48,10 +54,10 @@ class TestCacheDeletion:
             return_value=mock_response,
         ) as mock_db:
             # Clear cache first
-            store_cache._get_cached_store_agents.cache_clear()
+            routes._get_cached_store_agents.cache_clear()

             # First call - should hit database
-            result1 = await store_cache._get_cached_store_agents(
+            result1 = await routes._get_cached_store_agents(
                 featured=False,
                 creator=None,
                 sorted_by=None,
@@ -64,7 +70,7 @@ class TestCacheDeletion:
             assert result1.agents[0].agent_name == "Test Agent"

             # Second call with same params - should use cache
-            await store_cache._get_cached_store_agents(
+            await routes._get_cached_store_agents(
                 featured=False,
                 creator=None,
                 sorted_by=None,
@@ -76,7 +82,7 @@ class TestCacheDeletion:
             assert mock_db.call_count == 1  # No additional DB call

             # Third call with different params - should hit database
-            await store_cache._get_cached_store_agents(
+            await routes._get_cached_store_agents(
                 featured=True,  # Different param
                 creator=None,
                 sorted_by=None,
@@ -88,7 +94,7 @@ class TestCacheDeletion:
             assert mock_db.call_count == 2  # New DB call

             # Delete specific cache entry
-            deleted = store_cache._get_cached_store_agents.cache_delete(
+            deleted = routes._get_cached_store_agents.cache_delete(
                 featured=False,
                 creator=None,
                 sorted_by=None,
@@ -100,7 +106,7 @@ class TestCacheDeletion:
             assert deleted is True  # Entry was deleted

             # Try to delete non-existent entry
-            deleted = store_cache._get_cached_store_agents.cache_delete(
+            deleted = routes._get_cached_store_agents.cache_delete(
                 featured=False,
                 creator="nonexistent",
                 sorted_by=None,
@@ -112,7 +118,7 @@ class TestCacheDeletion:
             assert deleted is False  # Entry didn't exist

             # Call with deleted params - should hit database again
-            await store_cache._get_cached_store_agents(
+            await routes._get_cached_store_agents(
                 featured=False,
                 creator=None,
                 sorted_by=None,
@@ -124,7 +130,7 @@ class TestCacheDeletion:
             assert mock_db.call_count == 3  # New DB call after deletion

             # Call with featured=True - should still be cached
-            await store_cache._get_cached_store_agents(
+            await routes._get_cached_store_agents(
                 featured=True,
                 creator=None,
                 sorted_by=None,
@@ -135,11 +141,105 @@ class TestCacheDeletion:
             )
             assert mock_db.call_count == 3  # No additional DB call
+    @pytest.mark.asyncio
+    async def test_agent_details_cache_delete(self):
+        """Test that specific agent details cache entries can be deleted."""
+        mock_response = StoreAgentDetails(
+            store_listing_version_id="version1",
+            slug="test-agent",
+            agent_name="Test Agent",
+            agent_video="https://example.com/video.mp4",
+            agent_image=["https://example.com/image.jpg"],
+            creator="testuser",
+            creator_avatar="https://example.com/avatar.jpg",
+            sub_heading="Test subheading",
+            description="Test description",
+            categories=["productivity"],
+            runs=100,
+            rating=4.5,
+            versions=[],
+            last_updated=datetime.datetime(2024, 1, 1),
+        )
+
+        with patch(
+            "backend.server.v2.store.db.get_store_agent_details",
+            new_callable=AsyncMock,
+            return_value=mock_response,
+        ) as mock_db:
+            # Clear cache first
+            routes._get_cached_agent_details.cache_clear()
+
+            # First call - should hit database
+            await routes._get_cached_agent_details(
+                username="testuser", agent_name="testagent"
+            )
+            assert mock_db.call_count == 1
+
+            # Second call - should use cache
+            await routes._get_cached_agent_details(
+                username="testuser", agent_name="testagent"
+            )
+            assert mock_db.call_count == 1  # No additional DB call
+
+            # Delete specific entry
+            deleted = routes._get_cached_agent_details.cache_delete(
+                username="testuser", agent_name="testagent"
+            )
+            assert deleted is True
+
+            # Call again - should hit database
+            await routes._get_cached_agent_details(
+                username="testuser", agent_name="testagent"
+            )
+            assert mock_db.call_count == 2  # New DB call after deletion
+
+    @pytest.mark.asyncio
+    async def test_user_profile_cache_delete(self):
+        """Test that user profile cache entries can be deleted."""
+        mock_response = ProfileDetails(
+            name="Test User",
+            username="testuser",
+            description="Test profile",
+            links=["https://example.com"],
+        )
+
+        with patch(
+            "backend.server.v2.store.db.get_user_profile",
+            new_callable=AsyncMock,
+            return_value=mock_response,
+        ) as mock_db:
+            # Clear cache first
+            routes._get_cached_user_profile.cache_clear()
+
+            # First call - should hit database
+            await routes._get_cached_user_profile("user123")
+            assert mock_db.call_count == 1
+
+            # Second call - should use cache
+            await routes._get_cached_user_profile("user123")
+            assert mock_db.call_count == 1
+
+            # Different user - should hit database
+            await routes._get_cached_user_profile("user456")
+            assert mock_db.call_count == 2
+
+            # Delete specific user's cache
+            deleted = routes._get_cached_user_profile.cache_delete("user123")
+            assert deleted is True
+
+            # user123 should hit database again
+            await routes._get_cached_user_profile("user123")
+            assert mock_db.call_count == 3
+
+            # user456 should still be cached
+            await routes._get_cached_user_profile("user456")
+            assert mock_db.call_count == 3  # No additional DB call
     @pytest.mark.asyncio
     async def test_cache_info_after_deletions(self):
         """Test that cache_info correctly reflects deletions."""
         # Clear all caches first
-        store_cache._get_cached_store_agents.cache_clear()
+        routes._get_cached_store_agents.cache_clear()

         mock_response = StoreAgentsResponse(
             agents=[],
@@ -158,7 +258,7 @@ class TestCacheDeletion:
         ):
             # Add multiple entries
             for i in range(5):
-                await store_cache._get_cached_store_agents(
+                await routes._get_cached_store_agents(
                     featured=False,
                     creator=f"creator{i}",
                     sorted_by=None,
@@ -169,12 +269,12 @@ class TestCacheDeletion:
                 )

             # Check cache size
-            info = store_cache._get_cached_store_agents.cache_info()
+            info = routes._get_cached_store_agents.cache_info()
             assert info["size"] == 5

             # Delete some entries
             for i in range(2):
-                deleted = store_cache._get_cached_store_agents.cache_delete(
+                deleted = routes._get_cached_store_agents.cache_delete(
                     featured=False,
                     creator=f"creator{i}",
                     sorted_by=None,
@@ -186,7 +286,7 @@ class TestCacheDeletion:
                 assert deleted is True

             # Check cache size after deletion
-            info = store_cache._get_cached_store_agents.cache_info()
+            info = routes._get_cached_store_agents.cache_info()
             assert info["size"] == 3

     @pytest.mark.asyncio
@@ -207,10 +307,10 @@ class TestCacheDeletion:
             new_callable=AsyncMock,
             return_value=mock_response,
         ) as mock_db:
-            store_cache._get_cached_store_agents.cache_clear()
+            routes._get_cached_store_agents.cache_clear()

             # Test with all parameters
-            await store_cache._get_cached_store_agents(
+            await routes._get_cached_store_agents(
                 featured=True,
                 creator="testuser",
                 sorted_by="rating",
@@ -222,7 +322,7 @@ class TestCacheDeletion:
             assert mock_db.call_count == 1

             # Delete with exact same parameters
-            deleted = store_cache._get_cached_store_agents.cache_delete(
+            deleted = routes._get_cached_store_agents.cache_delete(
                 featured=True,
                 creator="testuser",
                 sorted_by="rating",
@@ -234,7 +334,7 @@ class TestCacheDeletion:
             assert deleted is True

             # Try to delete with slightly different parameters
-            deleted = store_cache._get_cached_store_agents.cache_delete(
+            deleted = routes._get_cached_store_agents.cache_delete(
                 featured=True,
                 creator="testuser",
                 sorted_by="rating",
@@ -245,6 +345,150 @@ class TestCacheDeletion:
             )
             assert deleted is False  # Different parameters, not in cache
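The behaviour these tests pin down — `cache_delete` removes only the entry whose keyword arguments match exactly, other entries stay cached, and `cache_info()["size"]` reflects the deletion — can be sketched with a toy in-memory decorator. This is an illustration only, not the project's actual `@cached` implementation:

```python
import functools

_MISSING = object()  # sentinel so a cached None value still counts as "existed"

def cached(func):
    """Toy in-process cache keyed on the exact call arguments."""
    cache = {}

    def make_key(args, kwargs):
        # Kwargs are sorted so keyword order does not change the key.
        return (args, tuple(sorted(kwargs.items())))

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        key = make_key(args, kwargs)
        if key not in cache:
            cache[key] = func(*args, **kwargs)
        return cache[key]

    def cache_delete(*args, **kwargs):
        # True only if an entry with these exact arguments existed.
        return cache.pop(make_key(args, kwargs), _MISSING) is not _MISSING

    wrapper.cache_delete = cache_delete
    wrapper.cache_clear = cache.clear
    wrapper.cache_info = lambda: {"size": len(cache)}
    return wrapper
```

Deleting with `creator="nonexistent"` returns `False` because the key is built from the exact argument tuple, mirroring the assertions above.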
+    @pytest.mark.asyncio
+    async def test_clear_submissions_cache_page_size_consistency(self):
+        """
+        Test that _clear_submissions_cache uses the correct page_size.
+        This test ensures that if the default page_size in routes changes,
+        the hardcoded value in _clear_submissions_cache must also change.
+        """
+        from backend.server.v2.store.model import StoreSubmissionsResponse
+
+        mock_response = StoreSubmissionsResponse(
+            submissions=[],
+            pagination=Pagination(
+                total_items=0,
+                total_pages=1,
+                current_page=1,
+                page_size=20,
+            ),
+        )
+
+        with patch(
+            "backend.server.v2.store.db.get_store_submissions",
+            new_callable=AsyncMock,
+            return_value=mock_response,
+        ):
+            # Clear cache first
+            routes._get_cached_submissions.cache_clear()
+
+            # Populate cache with multiple pages using the default page_size
+            DEFAULT_PAGE_SIZE = 20  # This should match the default in routes.py
+            user_id = "test_user"
+
+            # Add entries for pages 1-5
+            for page in range(1, 6):
+                await routes._get_cached_submissions(
+                    user_id=user_id, page=page, page_size=DEFAULT_PAGE_SIZE
+                )
+
+            # Verify cache has entries
+            cache_info_before = routes._get_cached_submissions.cache_info()
+            assert cache_info_before["size"] == 5
+
+            # Call _clear_submissions_cache
+            routes._clear_submissions_cache(user_id, num_pages=20)
+
+            # All entries should be cleared
+            cache_info_after = routes._get_cached_submissions.cache_info()
+            assert (
+                cache_info_after["size"] == 0
+            ), "Cache should be empty after _clear_submissions_cache"
+
+    @pytest.mark.asyncio
+    async def test_clear_submissions_cache_detects_page_size_mismatch(self):
+        """
+        Test that detects if _clear_submissions_cache is using wrong page_size.
+        If this test fails, it means the hardcoded page_size in _clear_submissions_cache
+        doesn't match the default page_size used in the routes.
+        """
+        from backend.server.v2.store.model import StoreSubmissionsResponse
+
+        mock_response = StoreSubmissionsResponse(
+            submissions=[],
+            pagination=Pagination(
+                total_items=0,
+                total_pages=1,
+                current_page=1,
+                page_size=20,
+            ),
+        )
+
+        with patch(
+            "backend.server.v2.store.db.get_store_submissions",
+            new_callable=AsyncMock,
+            return_value=mock_response,
+        ):
+            # Clear cache first
+            routes._get_cached_submissions.cache_clear()
+
+            # WRONG_PAGE_SIZE simulates what happens if someone changes
+            # the default page_size in routes but forgets to update _clear_submissions_cache
+            WRONG_PAGE_SIZE = 25  # Different from the hardcoded value in cache.py
+            user_id = "test_user"
+
+            # Populate cache with the "wrong" page_size
+            for page in range(1, 6):
+                await routes._get_cached_submissions(
+                    user_id=user_id, page=page, page_size=WRONG_PAGE_SIZE
+                )
+
+            # Verify cache has entries
+            cache_info_before = routes._get_cached_submissions.cache_info()
+            assert cache_info_before["size"] == 5
+
+            # Call _clear_submissions_cache (which uses page_size=20 hardcoded)
+            routes._clear_submissions_cache(user_id, num_pages=20)
+
+            # If page_size is mismatched, entries won't be cleared
+            cache_info_after = routes._get_cached_submissions.cache_info()
+
+            # This assertion will FAIL if _clear_submissions_cache uses wrong page_size
+            assert (
+                cache_info_after["size"] == 5
+            ), "Cache entries with different page_size should NOT be cleared (this is expected)"
+
+    @pytest.mark.asyncio
+    async def test_my_agents_cache_needs_clearing_too(self):
+        """
+        Test that demonstrates _get_cached_my_agents also needs cache clearing.
+        Currently there's no _clear_my_agents_cache function, but there should be.
+        """
+        from backend.server.v2.store.model import MyAgentsResponse
+
+        mock_response = MyAgentsResponse(
+            agents=[],
+            pagination=Pagination(
+                total_items=0,
+                total_pages=1,
+                current_page=1,
+                page_size=20,
+            ),
+        )
+
+        with patch(
+            "backend.server.v2.store.db.get_my_agents",
+            new_callable=AsyncMock,
+            return_value=mock_response,
+        ):
+            routes._get_cached_my_agents.cache_clear()
+
+            DEFAULT_PAGE_SIZE = 20
+            user_id = "test_user"
+
+            # Populate cache
+            for page in range(1, 6):
+                await routes._get_cached_my_agents(
+                    user_id=user_id, page=page, page_size=DEFAULT_PAGE_SIZE
+                )
+
+            cache_info = routes._get_cached_my_agents.cache_info()
+            assert cache_info["size"] == 5
+
+            # NOTE: Currently there's no _clear_my_agents_cache function
+            # If we implement one, it should clear all pages consistently
+            # For now we document this as a TODO


 if __name__ == "__main__":
     # Run the tests
@@ -329,3 +329,7 @@ class WebsocketServer(AppProcess):
             port=Config().websocket_server_port,
             log_config=None,
         )
+
+    def cleanup(self):
+        super().cleanup()
+        logger.info(f"[{self.service_name}] ⏳ Shutting down WebSocket Server...")
@@ -12,7 +12,6 @@ Provides decorators for caching function results with support for:
 import asyncio
 import inspect
 import logging
-import pickle
 import threading
 import time
 from dataclasses import dataclass
@@ -59,7 +58,9 @@ def _get_cache_pool() -> ConnectionPool:
     return _cache_pool


-redis = Redis(connection_pool=_get_cache_pool())
+def _get_redis_client() -> Redis:
+    """Get a Redis client from the connection pool."""
+    return Redis(connection_pool=_get_cache_pool())


 @dataclass
@@ -109,11 +110,11 @@ def _make_hashable_key(
     return (hashable_args, hashable_kwargs)


-def _make_redis_key(key: tuple[Any, ...], func_name: str) -> str:
+def _make_redis_key(key: tuple[Any, ...]) -> str:
     """Convert a hashable key tuple to a Redis key string."""
     # Ensure key is already hashable
     hashable_key = key if isinstance(key, tuple) else (key,)
-    return f"cache:{func_name}:{hash(hashable_key)}"
+    return f"cache:{hash(hashable_key)}"


 @runtime_checkable
@@ -176,6 +177,9 @@ def cached(
         def _get_from_redis(redis_key: str) -> Any | None:
             """Get value from Redis, optionally refreshing TTL."""
             try:
+                import pickle
+
+                redis = _get_redis_client()
                 if refresh_ttl_on_get:
                     # Use GETEX to get value and refresh expiry atomically
                     cached_bytes = redis.getex(redis_key, ex=ttl_seconds)
@@ -193,6 +197,9 @@ def cached(
         def _set_to_redis(redis_key: str, value: Any) -> None:
             """Set value in Redis with TTL."""
             try:
+                import pickle
+
+                redis = _get_redis_client()
                 pickled_value = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
                 redis.setex(redis_key, ttl_seconds, pickled_value)
             except Exception as e:
@@ -232,15 +239,13 @@ def cached(
                 loop = None

             if loop not in _event_loop_locks:
-                return _event_loop_locks.setdefault(loop, asyncio.Lock())
+                _event_loop_locks[loop] = asyncio.Lock()
+            return _event_loop_locks[loop]

         @wraps(target_func)
         async def async_wrapper(*args: P.args, **kwargs: P.kwargs):
             key = _make_hashable_key(args, kwargs)
-            redis_key = (
-                _make_redis_key(key, target_func.__name__) if shared_cache else ""
-            )
+            redis_key = _make_redis_key(key) if shared_cache else ""

             # Fast path: check cache without lock
             if shared_cache:
@@ -285,9 +290,7 @@ def cached(
         @wraps(target_func)
         def sync_wrapper(*args: P.args, **kwargs: P.kwargs):
             key = _make_hashable_key(args, kwargs)
-            redis_key = (
-                _make_redis_key(key, target_func.__name__) if shared_cache else ""
-            )
+            redis_key = _make_redis_key(key) if shared_cache else ""

             # Fast path: check cache without lock
             if shared_cache:
@@ -329,14 +332,13 @@ def cached(
         def cache_clear(pattern: str | None = None) -> None:
             """Clear cache entries. If pattern provided, clear matching entries."""
             if shared_cache:
+                redis = _get_redis_client()
                 if pattern:
                     # Clear entries matching pattern
-                    keys = list(
-                        redis.scan_iter(f"cache:{target_func.__name__}:{pattern}")
-                    )
+                    keys = list(redis.scan_iter(f"cache:{pattern}", count=100))
                 else:
                     # Clear all cache keys
-                    keys = list(redis.scan_iter(f"cache:{target_func.__name__}:*"))
+                    keys = list(redis.scan_iter("cache:*", count=100))

                 if keys:
                     pipeline = redis.pipeline()
@@ -354,7 +356,8 @@ def cached(

         def cache_info() -> dict[str, int | None]:
             if shared_cache:
-                cache_keys = list(redis.scan_iter(f"cache:{target_func.__name__}:*"))
+                redis = _get_redis_client()
+                cache_keys = list(redis.scan_iter("cache:*"))
                 return {
                     "size": len(cache_keys),
                     "maxsize": None,  # Redis manages its own size
@@ -371,7 +374,8 @@ def cached(
             """Delete a specific cache entry. Returns True if entry existed."""
             key = _make_hashable_key(args, kwargs)
             if shared_cache:
-                redis_key = _make_redis_key(key, target_func.__name__)
+                redis = _get_redis_client()
+                redis_key = _make_redis_key(key)
                 if redis.exists(redis_key):
                     redis.delete(redis_key)
                     return True
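The key-construction change above is easy to see in isolation. A namespaced key (`cache:{func_name}:{hash}`) keeps entries for different functions in separate key spaces, while a flat key (`cache:{hash}`) can collide when two functions are called with the same arguments. A hedged, dependency-free sketch of the two schemes (not the project's exact code):

```python
def _hashable(args, kwargs):
    # Sort kwargs so keyword order does not affect the key.
    return (args, tuple(sorted(kwargs.items())))

def make_key_namespaced(func_name, args, kwargs):
    # Old scheme in the diff: key space is per function.
    return f"cache:{func_name}:{hash(_hashable(args, kwargs))}"

def make_key_flat(args, kwargs):
    # New scheme in the diff: key depends only on the arguments.
    return f"cache:{hash(_hashable(args, kwargs))}"
```

With the flat scheme, `f(1)` and `g(1)` map to the same Redis key, which is why `cache_clear` in the new code has to scan the whole `cache:*` namespace instead of `cache:{func_name}:*`.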
@@ -33,14 +33,12 @@ def get_database_manager_client() -> "DatabaseManagerClient":


 @thread_cached
-def get_database_manager_async_client(
-    should_retry: bool = True,
-) -> "DatabaseManagerAsyncClient":
+def get_database_manager_async_client() -> "DatabaseManagerAsyncClient":
     """Get a thread-cached DatabaseManagerAsyncClient with request retry enabled."""
     from backend.executor import DatabaseManagerAsyncClient
     from backend.util.service import get_service_client

-    return get_service_client(DatabaseManagerAsyncClient, request_retry=should_retry)
+    return get_service_client(DatabaseManagerAsyncClient, request_retry=True)


 @thread_cached
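A `@thread_cached` factory like the one above returns one cached instance per thread, so each thread builds its client at most once. A minimal sketch of what such a decorator might do (an illustration, not the project's implementation):

```python
import functools
import threading

def thread_cached(func):
    """Cache results per thread using thread-local storage."""
    local = threading.local()

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        key = (args, tuple(sorted(kwargs.items())))
        cache = getattr(local, "cache", None)
        if cache is None:
            # First call on this thread: create its private cache.
            cache = local.cache = {}
        if key not in cache:
            cache[key] = func(*args, **kwargs)
        return cache[key]

    return wrapper
```

Because the cache lives in `threading.local()`, a second thread calling the same factory triggers a fresh construction rather than sharing the first thread's instance.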
@@ -86,9 +86,3 @@ class GraphValidationError(ValueError):
                 for node_id, errors in self.node_errors.items()
             ]
         )
-
-
-class DatabaseError(Exception):
-    """Raised when there is an error interacting with the database"""
-
-    pass
@@ -35,12 +35,6 @@ class Flag(str, Enum):
     AI_ACTIVITY_STATUS = "ai-agent-execution-summary"
     BETA_BLOCKS = "beta-blocks"
     AGENT_ACTIVITY = "agent-activity"
-    ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
-
-
-def is_configured() -> bool:
-    """Check if LaunchDarkly is configured with an SDK key."""
-    return bool(settings.secrets.launch_darkly_sdk_key)


 def get_client() -> LDClient:
@@ -63,9 +57,9 @@ def initialize_launchdarkly() -> None:
     config = Config(sdk_key)
     ldclient.set_config(config)

-    global _is_initialized
-    _is_initialized = True
     if ldclient.get().is_initialized():
+        global _is_initialized
+        _is_initialized = True
         logger.info("LaunchDarkly client initialized successfully")
     else:
         logger.error("LaunchDarkly client failed to initialize")
@@ -218,8 +212,7 @@ def feature_flag(

             if not get_client().is_initialized():
                 logger.warning(
-                    "LaunchDarkly not initialized, "
-                    f"using default {flag_key}={repr(default)}"
+                    f"LaunchDarkly not initialized, using default={default}"
                 )
                 is_enabled = default
             else:
@@ -233,9 +226,8 @@ def feature_flag(
             else:
                 # Log warning and use default for non-boolean values
                 logger.warning(
-                    f"Feature flag {flag_key} returned non-boolean value: "
-                    f"{repr(flag_value)} (type: {type(flag_value).__name__}). "
-                    f"Using default value {repr(default)}"
+                    f"Feature flag {flag_key} returned non-boolean value: {flag_value} (type: {type(flag_value).__name__}). "
+                    f"Using default={default}"
                 )
                 is_enabled = default
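The non-boolean guard warned about above matters because flag variations can come back as strings, numbers, or `None`. A hedged sketch of the coercion rule (assumption: anything that is not exactly a `bool` falls back to the caller's default, as the warning text describes):

```python
def coerce_flag(flag_value, default: bool) -> bool:
    """Return flag_value only if it is a real bool; otherwise the default."""
    if isinstance(flag_value, bool):
        return flag_value
    # Non-boolean variation (e.g. "true", 1, None): use the default.
    return default
```

Note that `isinstance(1, bool)` is `False` in Python even though `bool` subclasses `int`, so numeric variations also fall back to the default.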
@@ -1,22 +1,26 @@
-import logging
+import json
 import re
-from typing import Any, Type, TypeVar, overload
+from typing import Any, Type, TypeGuard, TypeVar, overload

 import jsonschema
 import orjson
-from fastapi.encoders import jsonable_encoder as to_dict
+from fastapi.encoders import jsonable_encoder
 from prisma import Json
 from pydantic import BaseModel

-from .truncate import truncate
 from .type import type_match

-logger = logging.getLogger(__name__)
-
 # Precompiled regex to remove PostgreSQL-incompatible control characters
 # Removes \u0000-\u0008, \u000B-\u000C, \u000E-\u001F, \u007F (keeps tab \u0009, newline \u000A, carriage return \u000D)
 POSTGRES_CONTROL_CHARS = re.compile(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]")


+def to_dict(data) -> dict:
+    if isinstance(data, BaseModel):
+        data = data.model_dump()
+    return jsonable_encoder(data)
+
+
 def dumps(
     data: Any, *args: Any, indent: int | None = None, option: int = 0, **kwargs: Any
 ) -> str:
@@ -105,57 +109,34 @@ def validate_with_jsonschema(
         return str(e)


-def _sanitize_string(value: str) -> str:
-    """Remove PostgreSQL-incompatible control characters from string."""
-    return POSTGRES_CONTROL_CHARS.sub("", value)
+def is_list_of_basemodels(value: object) -> TypeGuard[list[BaseModel]]:
+    return isinstance(value, list) and all(
+        isinstance(item, BaseModel) for item in value
+    )


-def sanitize_json(data: Any) -> Any:
-    try:
-        # Use two-pass approach for consistent string sanitization:
-        # 1. First convert to basic JSON-serializable types (handles Pydantic models)
-        # 2. Then sanitize strings in the result
-        basic_result = to_dict(data)
-        return to_dict(basic_result, custom_encoder={str: _sanitize_string})
-    except Exception as e:
-        # Log the failure and fall back to string representation
-        logger.error(
-            "SafeJson fallback to string representation due to serialization error: %s (%s). "
-            "Data type: %s, Data preview: %s",
-            type(e).__name__,
-            truncate(str(e), 200),
-            type(data).__name__,
-            truncate(str(data), 100),
-        )
-
-        # Ultimate fallback: convert to string representation and sanitize
-        return _sanitize_string(str(data))
+def convert_pydantic_to_json(output_data: Any) -> Any:
+    if isinstance(output_data, BaseModel):
+        return output_data.model_dump()
+    if is_list_of_basemodels(output_data):
+        return [item.model_dump() for item in output_data]
+    return output_data


-class SafeJson(Json):
+def SafeJson(data: Any) -> Json:
     """
-    Safely serialize data and return Prisma's Json type.
-    Sanitizes control characters to prevent PostgreSQL 22P05 errors.
-
-    This function:
-    1. Converts Pydantic models to dicts (recursively using to_dict)
-    2. Recursively removes PostgreSQL-incompatible control characters from strings
-    3. Returns a Prisma Json object safe for database storage
-
-    Uses to_dict (jsonable_encoder) with a custom encoder to handle both Pydantic
-    conversion and control character sanitization in a two-pass approach.
-
-    Args:
-        data: Input data to sanitize and convert to Json
-
-    Returns:
-        Prisma Json object with control characters removed
-
-    Examples:
-        >>> SafeJson({"text": "Hello\\x00World"})  # null char removed
-        >>> SafeJson({"path": "C:\\\\temp"})  # backslashes preserved
-        >>> SafeJson({"data": "Text\\\\u0000here"})  # literal backslash-u preserved
+    Sanitizes null bytes to prevent PostgreSQL 22P05 errors.
     """
+    if isinstance(data, BaseModel):
+        json_string = data.model_dump_json(
+            warnings="error",
+            exclude_none=True,
+            fallback=lambda v: None,
+        )
+    else:
+        json_string = dumps(data, default=lambda v: None)

-    def __init__(self, data: Any):
-        super().__init__(sanitize_json(data))
+    # Remove PostgreSQL-incompatible control characters in single regex operation
+    sanitized_json = POSTGRES_CONTROL_CHARS.sub("", json_string)
+    return Json(json.loads(sanitized_json))
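The character class in `POSTGRES_CONTROL_CHARS` is the core of both `SafeJson` variants above: it strips the control characters PostgreSQL rejects (causing error 22P05) while deliberately keeping tab (`\x09`), newline (`\x0A`), and carriage return (`\x0D`). A small self-contained demonstration of the same regex applied to raw strings:

```python
import re

# Same character class as the diff above: \x00-\x08, \x0B-\x0C, \x0E-\x1F, \x7F.
POSTGRES_CONTROL_CHARS = re.compile(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]")

def sanitize(value: str) -> str:
    """Drop PostgreSQL-incompatible control characters, keep tab/newline/CR."""
    return POSTGRES_CONTROL_CHARS.sub("", value)
```

One subtlety the diff's old docstring hints at: the regex matches actual control characters, not the six-character escape sequence `\u0000` in already-serialized JSON, so where in the pipeline sanitization runs (raw strings vs. serialized output) changes what it catches.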
@@ -8,7 +8,10 @@ settings = Settings()
 def configure_logging():
     import autogpt_libs.logging.config

-    if not is_structured_logging_enabled():
+    if (
+        settings.config.behave_as == BehaveAs.LOCAL
+        or settings.config.app_env == AppEnvironment.LOCAL
+    ):
         autogpt_libs.logging.config.configure_logging(force_cloud_logging=False)
     else:
         autogpt_libs.logging.config.configure_logging(force_cloud_logging=True)
@@ -17,14 +20,6 @@ def configure_logging():
     logging.getLogger("httpx").setLevel(logging.WARNING)


-def is_structured_logging_enabled() -> bool:
-    """Check if structured logging (cloud logging) is enabled."""
-    return not (
-        settings.config.behave_as == BehaveAs.LOCAL
-        or settings.config.app_env == AppEnvironment.LOCAL
-    )
-
-
 class TruncatedLogger:
     def __init__(
         self,
@@ -3,17 +3,13 @@ from enum import Enum

 import sentry_sdk
 from pydantic import SecretStr
-from sentry_sdk.integrations import DidNotEnable
 from sentry_sdk.integrations.anthropic import AnthropicIntegration
 from sentry_sdk.integrations.asyncio import AsyncioIntegration
-from sentry_sdk.integrations.launchdarkly import LaunchDarklyIntegration
 from sentry_sdk.integrations.logging import LoggingIntegration

-from backend.util import feature_flag
 from backend.util.settings import Settings

 settings = Settings()
-logger = logging.getLogger(__name__)


 class DiscordChannel(str, Enum):
@@ -23,12 +19,6 @@ class DiscordChannel(str, Enum):

 def sentry_init():
     sentry_dsn = settings.secrets.sentry_dsn
-    integrations = []
-    if feature_flag.is_configured():
-        try:
-            integrations.append(LaunchDarklyIntegration(feature_flag.get_client()))
-        except DidNotEnable as e:
-            logger.error(f"Error enabling LaunchDarklyIntegration for Sentry: {e}")
     sentry_sdk.init(
         dsn=sentry_dsn,
         traces_sample_rate=1.0,
@@ -41,8 +31,7 @@ def sentry_init():
             AnthropicIntegration(
                 include_prompts=False,
             ),
-        ]
-        + integrations,
+        ],
     )
@@ -8,9 +8,18 @@ from typing import Optional

 from backend.util.logging import configure_logging
 from backend.util.metrics import sentry_init
-from backend.util.settings import set_service_name

 logger = logging.getLogger(__name__)
+_SERVICE_NAME = "MainProcess"
+
+
+def get_service_name():
+    return _SERVICE_NAME
+
+
+def set_service_name(name: str):
+    global _SERVICE_NAME
+    _SERVICE_NAME = name


 class AppProcess(ABC):
@@ -19,8 +28,7 @@ class AppProcess(ABC):
     """

     process: Optional[Process] = None
-    _shutting_down: bool = False
-    _cleaned_up: bool = False
+    cleaned_up = False

     if "forkserver" in get_all_start_methods():
         set_start_method("forkserver", force=True)
@@ -44,6 +52,7 @@ class AppProcess(ABC):
     def service_name(self) -> str:
         return self.__class__.__name__

+    @abstractmethod
     def cleanup(self):
         """
         Implement this method on a subclass to do post-execution cleanup,
@@ -65,8 +74,7 @@ class AppProcess(ABC):
             self.run()
         except BaseException as e:
             logger.warning(
-                f"[{self.service_name}] 🛑 Terminating because of {type(e).__name__}: {e}",  # noqa
-                exc_info=e if not isinstance(e, SystemExit) else None,
+                f"[{self.service_name}] Termination request: {type(e).__name__}; {e} executing cleanup."
             )
             # Send error to Sentry before cleanup
             if not isinstance(e, (KeyboardInterrupt, SystemExit)):
@@ -77,12 +85,8 @@ class AppProcess(ABC):
             except Exception:
                 pass  # Silently ignore if Sentry isn't available
         finally:
-            if not self._cleaned_up:
-                self._cleaned_up = True
-                logger.info(f"[{self.service_name}] 🧹 Running cleanup")
-                self.cleanup()
-                logger.info(f"[{self.service_name}] ✅ Cleanup done")
-            logger.info(f"[{self.service_name}] 🛑 Terminated")
+            self.cleanup()
+            logger.info(f"[{self.service_name}] Terminated.")

     @staticmethod
     def llprint(message: str):
@@ -93,8 +97,8 @@ class AppProcess(ABC):
         os.write(sys.stdout.fileno(), (message + "\n").encode())

     def _self_terminate(self, signum: int, frame):
-        if not self._shutting_down:
-            self._shutting_down = True
+        if not self.cleaned_up:
+            self.cleaned_up = True
             sys.exit(0)
         else:
             self.llprint(
@@ -13,7 +13,7 @@ import idna
 from aiohttp import FormData, abc
 from tenacity import retry, retry_if_result, wait_exponential_jitter

-from backend.util.json import loads
+from backend.util.json import json

 # Retry status codes for which we will automatically retry the request
 THROTTLE_RETRY_STATUS_CODES: set[int] = {429, 500, 502, 503, 504, 408}
@@ -175,15 +175,10 @@ async def validate_url(
             f"for hostname {ascii_hostname} is not allowed."
         )

-    # Reconstruct the netloc with IDNA-encoded hostname and preserve port
-    netloc = ascii_hostname
-    if parsed.port:
-        netloc = f"{ascii_hostname}:{parsed.port}"
-
     return (
         URL(
             parsed.scheme,
-            netloc,
+            ascii_hostname,
             quote(parsed.path, safe="/%:@"),
             parsed.params,
             parsed.query,
@@ -264,7 +259,7 @@ class Response:
         """
         Parse the body as JSON and return the resulting Python object.
        """
-        return loads(
+        return json.loads(
            self.content.decode(encoding or "utf-8", errors="replace"), **kwargs
         )
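The netloc reconstruction being removed above exists because an IDNA-encoded hostname alone drops the port: the rebuilt netloc must re-attach `parsed.port` when one was given. A standalone sketch using the stdlib's `idna` codec (the project imports the third-party `idna` package, so this is an approximation for the common cases):

```python
from urllib.parse import urlparse

def rebuild_netloc(url: str) -> str:
    """IDNA-encode the hostname and preserve the port, as the removed code did."""
    parsed = urlparse(url)
    ascii_hostname = parsed.hostname.encode("idna").decode("ascii")
    if parsed.port:
        return f"{ascii_hostname}:{parsed.port}"
    return ascii_hostname
```

Without the port handling, a URL like `http://example.com:8080/` would be rewritten to target port 80, which is the regression the old comment guarded against.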
@@ -13,80 +13,41 @@ from tenacity import (
     wait_exponential_jitter,
 )
 
-from backend.util.settings import get_service_name
+from backend.util.process import get_service_name
 
 logger = logging.getLogger(__name__)
 
 # Alert threshold for excessive retries
 EXCESSIVE_RETRY_THRESHOLD = 50
 
-# Rate limiting for alerts - track last alert time per function+error combination
-_alert_rate_limiter = {}
-_rate_limiter_lock = threading.Lock()
-ALERT_RATE_LIMIT_SECONDS = 300  # 5 minutes between same alerts
-
-
-def should_send_alert(func_name: str, exception: Exception, context: str = "") -> bool:
-    """Check if we should send an alert based on rate limiting."""
-    # Create a unique key for this function+error+context combination
-    error_signature = (
-        f"{context}:{func_name}:{type(exception).__name__}:{str(exception)[:100]}"
-    )
-    current_time = time.time()
-
-    with _rate_limiter_lock:
-        last_alert_time = _alert_rate_limiter.get(error_signature, 0)
-        if current_time - last_alert_time >= ALERT_RATE_LIMIT_SECONDS:
-            _alert_rate_limiter[error_signature] = current_time
-            return True
-        return False
-
-
-def send_rate_limited_discord_alert(
-    func_name: str, exception: Exception, context: str, alert_msg: str, channel=None
-) -> bool:
-    """
-    Send a Discord alert with rate limiting.
-
-    Returns True if alert was sent, False if rate limited.
-    """
-    if not should_send_alert(func_name, exception, context):
-        return False
-
-    try:
-        from backend.util.clients import get_notification_manager_client
-        from backend.util.metrics import DiscordChannel
-
-        notification_client = get_notification_manager_client()
-        notification_client.discord_system_alert(
-            alert_msg, channel or DiscordChannel.PLATFORM
-        )
-        return True
-
-    except Exception as alert_error:
-        logger.error(f"Failed to send Discord alert: {alert_error}")
-        return False
-
-
 def _send_critical_retry_alert(
     func_name: str, attempt_number: int, exception: Exception, context: str = ""
 ):
     """Send alert when a function is approaching the retry failure threshold."""
     try:
         # Import here to avoid circular imports
         from backend.util.clients import get_notification_manager_client
 
-        prefix = f"{context}: " if context else ""
-        if send_rate_limited_discord_alert(
-            func_name,
-            exception,
-            context,
-            f"🚨 CRITICAL: Operation Approaching Failure Threshold: {prefix}'{func_name}'\n\n"
-            f"Current attempt: {attempt_number}/{EXCESSIVE_RETRY_THRESHOLD}\n"
-            f"Error: {type(exception).__name__}: {exception}\n\n"
-            f"This operation is about to fail permanently. Investigate immediately.",
-        ):
+        notification_client = get_notification_manager_client()
+
+        prefix = f"{context}: " if context else ""
+        alert_msg = (
+            f"🚨 CRITICAL: Operation Approaching Failure Threshold: {prefix}'{func_name}'\n\n"
+            f"Current attempt: {attempt_number}/{EXCESSIVE_RETRY_THRESHOLD}\n"
+            f"Error: {type(exception).__name__}: {exception}\n\n"
+            f"This operation is about to fail permanently. Investigate immediately."
+        )
+
+        notification_client.discord_system_alert(alert_msg)
         logger.critical(
             f"CRITICAL ALERT SENT: Operation {func_name} at attempt {attempt_number}"
         )
 
     except Exception as alert_error:
         logger.error(f"Failed to send critical retry alert: {alert_error}")
         # Don't let alerting failures break the main flow
 
 
 def _create_retry_callback(context: str = ""):
     """Create a retry callback with optional context."""
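The rate limiter deleted in the hunk above deduplicates alerts by a (context, function, error) signature within a fixed time window. As a self-contained sketch of that scheme, with simplified names (the 300-second window matches the constant shown in the diff; everything else here is illustrative, not the project's API):

```python
# Sketch: allow one alert per unique (context, function, error) signature
# per ALERT_RATE_LIMIT_SECONDS window; thread-safe via a lock.
import threading
import time

ALERT_RATE_LIMIT_SECONDS = 300  # window length, as in the diff

_alert_times: dict[str, float] = {}
_lock = threading.Lock()


def should_send_alert(func_name: str, exception: Exception, context: str = "") -> bool:
    """Return True only for the first occurrence of a signature per window."""
    # Truncate the message so unbounded error text can't explode the key space
    signature = (
        f"{context}:{func_name}:{type(exception).__name__}:{str(exception)[:100]}"
    )
    now = time.time()
    with _lock:
        if now - _alert_times.get(signature, 0) >= ALERT_RATE_LIMIT_SECONDS:
            _alert_times[signature] = now
            return True
        return False


err = ValueError("boom")
first = should_send_alert("spend_credits", err, "svc")   # first occurrence: allowed
second = should_send_alert("spend_credits", err, "svc")  # duplicate: suppressed
```

A different function, context, or error type produces a different signature and is allowed through immediately.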
@@ -105,7 +66,7 @@ def _create_retry_callback(context: str = ""):
                 f"{type(exception).__name__}: {exception}"
             )
         else:
-            # Retry attempt - send critical alert only once at threshold (rate limited)
+            # Retry attempt - send critical alert only once at threshold
             if attempt_number == EXCESSIVE_RETRY_THRESHOLD:
                 _send_critical_retry_alert(
                     func_name, attempt_number, exception, context
@@ -170,7 +131,7 @@ def _log_prefix(resource_name: str, conn_id: str):
 def conn_retry(
     resource_name: str,
     action_name: str,
-    max_retry: int = 100,
+    max_retry: int = 5,
     max_wait: float = 30,
 ):
     conn_id = str(uuid4())
@@ -178,29 +139,10 @@ def conn_retry(
     def on_retry(retry_state):
         prefix = _log_prefix(resource_name, conn_id)
         exception = retry_state.outcome.exception()
-        attempt_number = retry_state.attempt_number
-        func_name = getattr(retry_state.fn, "__name__", "unknown")
-
         if retry_state.outcome.failed and retry_state.next_action is None:
             logger.error(f"{prefix} {action_name} failed after retries: {exception}")
         else:
-            if attempt_number == EXCESSIVE_RETRY_THRESHOLD:
-                if send_rate_limited_discord_alert(
-                    func_name,
-                    exception,
-                    f"{resource_name}_infrastructure",
-                    f"🚨 **Critical Infrastructure Connection Issue**\n"
-                    f"Resource: {resource_name}\n"
-                    f"Action: {action_name}\n"
-                    f"Function: {func_name}\n"
-                    f"Current attempt: {attempt_number}/{max_retry + 1}\n"
-                    f"Error: {type(exception).__name__}: {str(exception)[:200]}{'...' if len(str(exception)) > 200 else ''}\n\n"
-                    f"Infrastructure component is approaching failure threshold. Investigate immediately.",
-                ):
-                    logger.critical(
-                        f"INFRASTRUCTURE ALERT SENT: {resource_name} at {attempt_number} attempts"
-                    )
-
             logger.warning(
                 f"{prefix} {action_name} failed: {exception}. Retrying now..."
             )
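The `conn_retry` decorator above is built on tenacity, with `on_retry` as the per-failure hook. As a dependency-free illustration of the same shape (one initial attempt plus `max_retry` retries, a hook point where `on_retry` would log or alert) here is a hypothetical stdlib-only sketch; the names mirror the diff but this is not the project's implementation:

```python
# Sketch: retry a connection function up to max_retry extra times,
# re-raising the last exception if every attempt fails.
import functools
import time


def conn_retry(resource_name: str, action_name: str, max_retry: int = 5, delay: float = 0.0):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            last_exc = None
            for attempt in range(1, max_retry + 2):  # initial try + max_retry retries
                try:
                    return func(*args, **kwargs)
                except Exception as exc:
                    last_exc = exc
                    # on_retry hook: log f"{resource_name} {action_name} failed" here
                    time.sleep(delay)
            raise last_exc
        return wrapper
    return decorator


attempts = []


@conn_retry("Redis", "Connecting", max_retry=2)
def connect():
    attempts.append(1)
    if len(attempts) < 3:
        raise ConnectionError("not ready")
    return "connected"


result = connect()  # succeeds on the third attempt
```

The real decorator adds exponential backoff with jitter and a capped `max_wait` via tenacity, which this sketch omits for brevity.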
@@ -276,8 +218,8 @@ def continuous_retry(*, retry_delay: float = 1.0):
 
         @wraps(func)
         async def async_wrapper(*args, **kwargs):
-            counter = 0
             while True:
+                counter = 0
                 try:
                     return await func(*args, **kwargs)
                 except Exception as exc:
@@ -1,19 +1,8 @@
 import asyncio
-import threading
-import time
-from unittest.mock import Mock, patch
 
 import pytest
 
-from backend.util.retry import (
-    ALERT_RATE_LIMIT_SECONDS,
-    _alert_rate_limiter,
-    _rate_limiter_lock,
-    _send_critical_retry_alert,
-    conn_retry,
-    create_retry_decorator,
-    should_send_alert,
-)
+from backend.util.retry import conn_retry
 
 
 def test_conn_retry_sync_function():
@@ -58,194 +47,3 @@ async def test_conn_retry_async_function():
     with pytest.raises(ValueError) as e:
         await test_function()
     assert str(e.value) == "Test error"
-
-
-class TestRetryRateLimiting:
-    """Test the rate limiting functionality for critical retry alerts."""
-
-    def setup_method(self):
-        """Reset rate limiter state before each test."""
-        with _rate_limiter_lock:
-            _alert_rate_limiter.clear()
-
-    def test_should_send_alert_allows_first_occurrence(self):
-        """Test that the first occurrence of an error allows alert."""
-        exc = ValueError("test error")
-        assert should_send_alert("test_func", exc, "test_context") is True
-
-    def test_should_send_alert_rate_limits_duplicate(self):
-        """Test that duplicate errors are rate limited."""
-        exc = ValueError("test error")
-
-        # First call should be allowed
-        assert should_send_alert("test_func", exc, "test_context") is True
-
-        # Second call should be rate limited
-        assert should_send_alert("test_func", exc, "test_context") is False
-
-    def test_should_send_alert_allows_different_errors(self):
-        """Test that different errors are allowed even if same function."""
-        exc1 = ValueError("error 1")
-        exc2 = ValueError("error 2")
-
-        # First error should be allowed
-        assert should_send_alert("test_func", exc1, "test_context") is True
-
-        # Different error should also be allowed
-        assert should_send_alert("test_func", exc2, "test_context") is True
-
-    def test_should_send_alert_allows_different_contexts(self):
-        """Test that same error in different contexts is allowed."""
-        exc = ValueError("test error")
-
-        # First context should be allowed
-        assert should_send_alert("test_func", exc, "context1") is True
-
-        # Different context should also be allowed
-        assert should_send_alert("test_func", exc, "context2") is True
-
-    def test_should_send_alert_allows_different_functions(self):
-        """Test that same error in different functions is allowed."""
-        exc = ValueError("test error")
-
-        # First function should be allowed
-        assert should_send_alert("func1", exc, "test_context") is True
-
-        # Different function should also be allowed
-        assert should_send_alert("func2", exc, "test_context") is True
-
-    def test_should_send_alert_respects_time_window(self):
-        """Test that alerts are allowed again after the rate limit window."""
-        exc = ValueError("test error")
-
-        # First call should be allowed
-        assert should_send_alert("test_func", exc, "test_context") is True
-
-        # Immediately after should be rate limited
-        assert should_send_alert("test_func", exc, "test_context") is False
-
-        # Mock time to simulate passage of rate limit window
-        current_time = time.time()
-        with patch("backend.util.retry.time.time") as mock_time:
-            # Simulate time passing beyond rate limit window
-            mock_time.return_value = current_time + ALERT_RATE_LIMIT_SECONDS + 1
-            assert should_send_alert("test_func", exc, "test_context") is True
-
-    def test_should_send_alert_thread_safety(self):
-        """Test that rate limiting is thread-safe."""
-        exc = ValueError("test error")
-        results = []
-
-        def check_alert():
-            result = should_send_alert("test_func", exc, "test_context")
-            results.append(result)
-
-        # Create multiple threads trying to send the same alert
-        threads = [threading.Thread(target=check_alert) for _ in range(10)]
-
-        # Start all threads
-        for thread in threads:
-            thread.start()
-
-        # Wait for all threads to complete
-        for thread in threads:
-            thread.join()
-
-        # Only one thread should have been allowed to send the alert
-        assert sum(results) == 1
-        assert len([r for r in results if r is True]) == 1
-        assert len([r for r in results if r is False]) == 9
-
-    @patch("backend.util.clients.get_notification_manager_client")
-    def test_send_critical_retry_alert_rate_limiting(self, mock_get_client):
-        """Test that _send_critical_retry_alert respects rate limiting."""
-        mock_client = Mock()
-        mock_get_client.return_value = mock_client
-
-        exc = ValueError("spend_credits API error")
-
-        # First alert should be sent
-        _send_critical_retry_alert("spend_credits", 50, exc, "Service communication")
-        assert mock_client.discord_system_alert.call_count == 1
-
-        # Second identical alert should be rate limited (not sent)
-        _send_critical_retry_alert("spend_credits", 50, exc, "Service communication")
-        assert mock_client.discord_system_alert.call_count == 1  # Still 1, not 2
-
-        # Different error should be allowed
-        exc2 = ValueError("different API error")
-        _send_critical_retry_alert("spend_credits", 50, exc2, "Service communication")
-        assert mock_client.discord_system_alert.call_count == 2
-
-    @patch("backend.util.clients.get_notification_manager_client")
-    def test_send_critical_retry_alert_handles_notification_failure(
-        self, mock_get_client
-    ):
-        """Test that notification failures don't break the rate limiter."""
-        mock_client = Mock()
-        mock_client.discord_system_alert.side_effect = Exception("Notification failed")
-        mock_get_client.return_value = mock_client
-
-        exc = ValueError("test error")
-
-        # Should not raise exception even if notification fails
-        _send_critical_retry_alert("test_func", 50, exc, "test_context")
-
-        # Rate limiter should still work for subsequent calls
-        assert should_send_alert("test_func", exc, "test_context") is False
-
-    def test_error_signature_generation(self):
-        """Test that error signatures are generated correctly for rate limiting."""
-        # Test with long exception message (should be truncated to 100 chars)
-        long_message = "x" * 200
-        exc = ValueError(long_message)
-
-        # Should not raise exception and should work normally
-        assert should_send_alert("test_func", exc, "test_context") is True
-        assert should_send_alert("test_func", exc, "test_context") is False
-
-    def test_real_world_scenario_spend_credits_spam(self):
-        """Test the real-world scenario that was causing spam."""
-        # Simulate the exact error that was causing issues
-        exc = Exception(
-            "HTTP 500: Server error '500 Internal Server Error' for url 'http://autogpt-database-manager.prod-agpt.svc.cluster.local:8005/spend_credits'"
-        )
-
-        # First 50 attempts reach threshold - should send alert
-        with patch(
-            "backend.util.clients.get_notification_manager_client"
-        ) as mock_get_client:
-            mock_client = Mock()
-            mock_get_client.return_value = mock_client
-
-            _send_critical_retry_alert(
-                "_call_method_sync", 50, exc, "Service communication"
-            )
-            assert mock_client.discord_system_alert.call_count == 1
-
-            # Next 950 failures should not send alerts (rate limited)
-            for _ in range(950):
-                _send_critical_retry_alert(
-                    "_call_method_sync", 50, exc, "Service communication"
-                )
-
-            # Still only 1 alert sent total
-            assert mock_client.discord_system_alert.call_count == 1
-
-    @patch("backend.util.clients.get_notification_manager_client")
-    def test_retry_decorator_with_excessive_failures(self, mock_get_client):
-        """Test retry decorator behavior when it hits the alert threshold."""
-        mock_client = Mock()
-        mock_get_client.return_value = mock_client
-
-        @create_retry_decorator(
-            max_attempts=60, max_wait=0.1
-        )  # More than EXCESSIVE_RETRY_THRESHOLD, but fast
-        def always_failing_function():
-            raise ValueError("persistent failure")
-
-        with pytest.raises(ValueError):
-            always_failing_function()
-
-        # Should have sent exactly one alert at the threshold
-        assert mock_client.discord_system_alert.call_count == 1
@@ -4,12 +4,9 @@ import concurrent.futures
 import inspect
 import logging
 import os
-import signal
-import sys
 import threading
 import time
 from abc import ABC, abstractmethod
-from contextlib import asynccontextmanager
 from functools import update_wrapper
 from typing import (
     Any,
@@ -31,12 +28,11 @@ from fastapi import FastAPI, Request, responses
 from pydantic import BaseModel, TypeAdapter, create_model
 
 import backend.util.exceptions as exceptions
-from backend.monitoring.instrumentation import instrument_fastapi
 from backend.util.json import to_dict
 from backend.util.metrics import sentry_init
-from backend.util.process import AppProcess
+from backend.util.process import AppProcess, get_service_name
 from backend.util.retry import conn_retry, create_retry_decorator
-from backend.util.settings import Config, get_service_name
+from backend.util.settings import Config
 
 logger = logging.getLogger(__name__)
 T = TypeVar("T")
@@ -114,44 +110,14 @@ class BaseAppService(AppProcess, ABC):
         return target_host
 
     def run_service(self) -> None:
-        # HACK: run the main event loop outside the main thread to disable Uvicorn's
-        # internal signal handlers, since there is no config option for this :(
-        shared_asyncio_thread = threading.Thread(
-            target=self._run_shared_event_loop,
-            daemon=True,
-            name=f"{self.service_name}-shared-event-loop",
-        )
-        shared_asyncio_thread.start()
-        shared_asyncio_thread.join()
-
-    def _run_shared_event_loop(self) -> None:
-        try:
-            self.shared_event_loop.run_forever()
-        finally:
-            logger.info(f"[{self.service_name}] 🛑 Shared event loop stopped")
-            self.shared_event_loop.close()  # ensure held resources are released
+        while True:
+            time.sleep(10)
 
     def run_and_wait(self, coro: Coroutine[Any, Any, T]) -> T:
         return asyncio.run_coroutine_threadsafe(coro, self.shared_event_loop).result()
 
     def run(self):
-        self.shared_event_loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(self.shared_event_loop)
-
-    def cleanup(self):
-        """
-        **💡 Overriding `AppService.lifespan` may be a more convenient option.**
-
-        Implement this method on a subclass to do post-execution cleanup,
-        e.g. disconnecting from a database or terminating child processes.
-
-        **Note:** if you override this method in a subclass, it must call
-        `super().cleanup()` *at the end*!
-        """
-        # Stop the shared event loop to allow resource clean-up
-        self.shared_event_loop.call_soon_threadsafe(self.shared_event_loop.stop)
-
-        super().cleanup()
+        self.shared_event_loop = asyncio.get_event_loop()
 
 
 class RemoteCallError(BaseModel):
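The `run_and_wait` helper kept in the hunk above bridges worker threads into the service's shared asyncio loop via `run_coroutine_threadsafe`. A minimal self-contained sketch of that pattern (illustrative names only, not the `BaseAppService` class itself):

```python
# Sketch: a shared event loop runs in a background thread; other threads
# submit coroutines to it and block on the returned concurrent.futures.Future.
import asyncio
import threading

loop = asyncio.new_event_loop()
thread = threading.Thread(target=loop.run_forever, daemon=True)
thread.start()


def run_and_wait(coro):
    """Schedule `coro` on the shared loop from any thread and wait for its result."""
    return asyncio.run_coroutine_threadsafe(coro, loop).result()


async def add(a: int, b: int) -> int:
    await asyncio.sleep(0)  # stand-in for real async I/O
    return a + b


result = run_and_wait(add(2, 3))
loop.call_soon_threadsafe(loop.stop)  # shut the loop down when done
```

`run_coroutine_threadsafe` is the only safe way to hand work to a loop owned by another thread; calling `loop.run_until_complete` from a second thread would raise.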
@@ -212,7 +178,6 @@ EXCEPTION_MAPPING = {
 
 class AppService(BaseAppService, ABC):
     fastapi_app: FastAPI
-    http_server: uvicorn.Server | None = None
     log_level: str = "info"
 
     def set_log_level(self, log_level: str):
@@ -224,10 +189,11 @@ class AppService(BaseAppService, ABC):
     def _handle_internal_http_error(status_code: int = 500, log_error: bool = True):
         def handler(request: Request, exc: Exception):
             if log_error:
-                logger.error(
-                    f"{request.method} {request.url.path} failed: {exc}",
-                    exc_info=exc if status_code == 500 else None,
-                )
+                if status_code == 500:
+                    log = logger.exception
+                else:
+                    log = logger.error
+                log(f"{request.method} {request.url.path} failed: {exc}")
             return responses.JSONResponse(
                 status_code=status_code,
                 content=RemoteCallError(
@@ -289,13 +255,13 @@ class AppService(BaseAppService, ABC):
 
         return sync_endpoint
 
-    @conn_retry("FastAPI server", "Running FastAPI server")
+    @conn_retry("FastAPI server", "Starting FastAPI server")
     def __start_fastapi(self):
         logger.info(
             f"[{self.service_name}] Starting RPC server at http://{api_host}:{self.get_port()}"
         )
 
-        self.http_server = uvicorn.Server(
+        server = uvicorn.Server(
             uvicorn.Config(
                 self.fastapi_app,
                 host=api_host,
@@ -304,94 +270,18 @@ class AppService(BaseAppService, ABC):
                 log_level=self.log_level,
             )
         )
-        self.run_and_wait(self.http_server.serve())
-
-        # Perform clean-up when the server exits
-        if not self._cleaned_up:
-            self._cleaned_up = True
-            logger.info(f"[{self.service_name}] 🧹 Running cleanup")
-            self.cleanup()
-            logger.info(f"[{self.service_name}] ✅ Cleanup done")
-
-    def _self_terminate(self, signum: int, frame):
-        """Pass SIGTERM to Uvicorn so it can shut down gracefully"""
-        signame = signal.Signals(signum).name
-        if not self._shutting_down:
-            self._shutting_down = True
-            if self.http_server:
-                logger.info(
-                    f"[{self.service_name}] 🛑 Received {signame} ({signum}) - "
-                    "Entering RPC server graceful shutdown"
-                )
-                self.http_server.handle_exit(signum, frame)  # stop accepting requests
-
-                # NOTE: Actually stopping the process is triggered by:
-                # 1. The call to self.cleanup() at the end of __start_fastapi() 👆🏼
-                # 2. BaseAppService.cleanup() stopping the shared event loop
-            else:
-                logger.warning(
-                    f"[{self.service_name}] {signame} received before HTTP server init."
-                    " Terminating..."
-                )
-                sys.exit(0)
-
-        else:
-            # Expedite shutdown on second SIGTERM
-            logger.info(
-                f"[{self.service_name}] 🛑🛑 Received {signame} ({signum}), "
-                "but shutdown is already underway. Terminating..."
-            )
-            sys.exit(0)
-
-    @asynccontextmanager
-    async def lifespan(self, app: FastAPI):
-        """
-        The FastAPI/Uvicorn server's lifespan manager, used for setup and shutdown.
-
-        You can extend and use this in a subclass like:
-        ```
-        @asynccontextmanager
-        async def lifespan(self, app: FastAPI):
-            async with super().lifespan(app):
-                await db.connect()
-                yield
-                await db.disconnect()
-        ```
-        """
-        # Startup - this runs before Uvicorn starts accepting connections
-
-        yield
-
-        # Shutdown - this runs when FastAPI/Uvicorn shuts down
-        logger.info(f"[{self.service_name}] ✅ FastAPI has finished")
+        self.shared_event_loop.run_until_complete(server.serve())
 
     async def health_check(self) -> str:
-        """A method to check the health of the process."""
+        """
+        A method to check the health of the process.
+        """
         return "OK"
 
     def run(self):
         sentry_init()
         super().run()
 
-        self.fastapi_app = FastAPI(lifespan=self.lifespan)
-
-        # Add Prometheus instrumentation to all services
-        try:
-            instrument_fastapi(
-                self.fastapi_app,
-                service_name=self.service_name,
-                expose_endpoint=True,
-                endpoint="/metrics",
-                include_in_schema=False,
-            )
-        except ImportError:
-            logger.warning(
-                f"Prometheus instrumentation not available for {self.service_name}"
-            )
-        except Exception as e:
-            logger.error(
-                f"Failed to instrument {self.service_name} with Prometheus: {e}"
-            )
+        self.fastapi_app = FastAPI()
 
         # Register the exposed API routes.
         for attr_name, attr in vars(type(self)).items():
@@ -416,11 +306,7 @@ class AppService(BaseAppService, ABC):
             )
 
         # Start the FastAPI server in a separate thread.
-        api_thread = threading.Thread(
-            target=self.__start_fastapi,
-            daemon=True,
-            name=f"{self.service_name}-http-server",
-        )
+        api_thread = threading.Thread(target=self.__start_fastapi, daemon=True)
         api_thread.start()
 
         # Run the main service loop (blocking).
Some files were not shown because too many files have changed in this diff.