Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-12 16:48:06 -05:00)

Compare commits: 72 commits, feat/impro...fix/sql-in

72 commits (SHA1 only):
224411abd3, 6b241af79e, 320fb7d83a, 54552248f7, d8a5780ea2, 377657f8a1, ff71c940c9, 9967b3a7ce,
9db443960a, 9316100864, cbe0cee0fc, 7cbb1ed859, e06e7ff33f, acb946801b, 48ff225837, e2a9923f30,
39792d517e, a6a2f71458, 788b861bb7, e203e65dc4, bd03697ff2, efd37b7a36, bb0b45d7f7, 04df981115,
d25997b4f2, 11d55f6055, 063dc5cf65, b7646f3e58, 0befaf0a47, 93f58dec5e, 3da595f599, e5e60921a3,
90af8f8e1a, eba67e0a4b, 47bb89caeb, 271a520afa, 3988057032, a6c6e48f00, e72ce2f9e7, bd7a79a920,
3f546ae845, 097a19141d, c958c95d6b, 3e50cbd2cb, 1b69f1644d, d9035a233c, 972cbfc3de, 8f861b1bb2,
fa2731bb8b, 2dc0c97a52, 0bb2b87c32, a1d9b45238, 29895c290f, 73c0b6899a, 4c853a54d7, dfdd632161,
1ed224d481, 3b5d919399, 3c16de22ef, e4bc728d40, 2c6d85d15e, 8258338caf, 374f35874c, e62a56e8ba,
f3f9a60157, 3ed1c93ec0, 773f545cfd, 84ad4a9f95, 8610118ddc, ebb4ebb025, cb532e1c4d, 794aee25ab

94 .github/copilot-instructions.md (vendored)

@@ -12,6 +12,7 @@ This file provides comprehensive onboarding information for GitHub Copilot codin

- **Infrastructure** - Docker configurations, CI/CD, and development tools

**Primary Languages & Frameworks:**

- **Backend**: Python 3.10-3.13, FastAPI, Prisma ORM, PostgreSQL, RabbitMQ
- **Frontend**: TypeScript, Next.js 15, React, Tailwind CSS, Radix UI
- **Development**: Docker, Poetry, pnpm, Playwright, Storybook

@@ -23,15 +24,17 @@ This file provides comprehensive onboarding information for GitHub Copilot codin

**Always run these commands in the correct directory and in this order:**

1. **Initial Setup** (required once):

   ```bash
   # Clone and enter repository
   git clone <repo> && cd AutoGPT

   # Start all services (database, redis, rabbitmq, clamav)
   cd autogpt_platform && docker compose --profile local up deps --build --detach
   ```

2. **Backend Setup** (always run before backend development):

   ```bash
   cd autogpt_platform/backend
   poetry install  # Install dependencies
   ```

@@ -48,6 +51,7 @@ This file provides comprehensive onboarding information for GitHub Copilot codin

### Runtime Requirements

**Critical:** Always ensure Docker services are running before starting development:

```bash
cd autogpt_platform && docker compose --profile local up deps --build --detach
```

@@ -58,6 +62,7 @@ cd autogpt_platform && docker compose --profile local up deps --build --detach

### Development Commands

**Backend Development:**

```bash
cd autogpt_platform/backend
poetry run serve   # Start development server (port 8000)
@@ -68,6 +73,7 @@ poetry run lint    # Lint code (ruff) - run after format
```

**Frontend Development:**

```bash
cd autogpt_platform/frontend
pnpm dev   # Start development server (port 3000) - use for active development
```

@@ -81,23 +87,27 @@ pnpm storybook  # Start component development server

### Testing Strategy

**Backend Tests:**

- **Block Tests**: `poetry run pytest backend/blocks/test/test_block.py -xvs` (validates all blocks)
- **Specific Block**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[BlockName]' -xvs`
- **Snapshot Tests**: Use `--snapshot-update` when output changes, always review with `git diff`

**Frontend Tests:**

- **E2E Tests**: Always run `pnpm dev` before `pnpm test` (Playwright requires running instance)
- **Component Tests**: Use Storybook for isolated component development

### Critical Validation Steps

**Before committing changes:**

1. Run `poetry run format` (backend) and `pnpm format` (frontend)
2. Ensure all tests pass in modified areas
3. Verify Docker services are still running
4. Check that database migrations apply cleanly

**Common Issues & Workarounds:**

- **Prisma issues**: Run `poetry run prisma generate` after schema changes
- **Permission errors**: Ensure Docker has proper permissions
- **Port conflicts**: Check the `docker-compose.yml` file for the current list of exposed ports. You can list all mapped ports with:

@@ -108,6 +118,7 @@ pnpm storybook  # Start component development server

### Core Architecture

**AutoGPT Platform** (`autogpt_platform/`):

- `backend/` - FastAPI server with async support
  - `backend/backend/` - Core API logic
  - `backend/blocks/` - Agent execution blocks

@@ -121,6 +132,7 @@ pnpm storybook  # Start component development server

- `docker-compose.yml` - Development stack orchestration

**Key Configuration Files:**

- `pyproject.toml` - Python dependencies and tooling
- `package.json` - Node.js dependencies and scripts
- `schema.prisma` - Database schema and migrations

@@ -136,6 +148,7 @@ pnpm storybook  # Start component development server

### Development Workflow

**GitHub Actions**: Multiple CI/CD workflows in `.github/workflows/`

- `platform-backend-ci.yml` - Backend testing and validation
- `platform-frontend-ci.yml` - Frontend testing and validation
- `platform-fullstack-ci.yml` - End-to-end integration tests

@@ -146,11 +159,13 @@ pnpm storybook  # Start component development server

### Key Source Files

**Backend Entry Points:**

- `backend/backend/server/server.py` - FastAPI application setup
- `backend/backend/data/` - Database models and user management
- `backend/blocks/` - Agent execution blocks and logic

**Frontend Entry Points:**

- `frontend/src/app/layout.tsx` - Root application layout
- `frontend/src/app/page.tsx` - Home page
- `frontend/src/lib/supabase/` - Authentication and database client

@@ -160,6 +175,7 @@ pnpm storybook  # Start component development server

### Agent Block System

Agents are built using a visual block-based system where each block performs a single action. Blocks are defined in `backend/blocks/` and must include:

- Block definition with input/output schemas
- Execution logic with proper error handling
- Tests validating functionality

@@ -167,6 +183,7 @@ Agents are built using a visual block-based system where each block performs a s

### Database & ORM

**Prisma ORM** with PostgreSQL backend including pgvector for embeddings:

- Schema in `schema.prisma`
- Migrations in `backend/migrations/`
- Always run `prisma migrate dev` and `prisma generate` after schema changes

@@ -174,13 +191,15 @@ Agents are built using a visual block-based system where each block performs a s

## Environment Configuration

### Configuration Files Priority Order

1. **Backend**: `/backend/.env.default` → `/backend/.env` (user overrides)
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
3. **Platform**: `/.env.default` (Supabase/shared) → `/.env` (user overrides)
4. Docker Compose `environment:` sections override file-based config
5. Shell environment variables have highest precedence

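For reference, the file-and-shell part of this precedence can be reproduced in a few lines. This is an illustrative sketch only (it assumes python-dotenv and is not part of the repository); the Docker Compose `environment:` layer is applied by Compose itself:

```python
# Illustrative sketch, not repository code; assumes python-dotenv is installed.
import os

from dotenv import dotenv_values

config = {
    **dotenv_values("backend/.env.default"),  # repo defaults (lowest precedence)
    **dotenv_values("backend/.env"),          # user overrides
    **os.environ,                             # shell environment wins
}
```
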
### Docker Environment Setup

- All services use hardcoded defaults (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration

@@ -189,6 +208,7 @@ Agents are built using a visual block-based system where each block performs a s

## Advanced Development Patterns

### Adding New Blocks

1. Create file in `/backend/backend/blocks/`
2. Inherit from `Block` base class with input/output schemas
3. Implement `run` method with proper error handling

@@ -198,6 +218,7 @@ Agents are built using a visual block-based system where each block performs a s

7. Consider how inputs/outputs connect with other blocks in graph editor

### API Development

1. Update routes in `/backend/backend/server/routers/`
2. Add/update Pydantic models in same directory
3. Write tests alongside route files

@@ -205,21 +226,76 @@ Agents are built using a visual block-based system where each block performs a s

5. Run `poetry run test` to verify changes

### Frontend Development

1. Components in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for component development
4. Test user-facing features with Playwright E2E tests
5. Update protected routes in middleware when needed

**📖 Complete Frontend Guide**: See `autogpt_platform/frontend/CONTRIBUTING.md` and `autogpt_platform/frontend/.cursorrules` for comprehensive patterns and conventions.

**Quick Reference:**

**Component Structure:**

- Separate render logic from data/behavior
- Structure: `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
- Exception: Small components (3-4 lines of logic) can be inline
- Render-only components can be direct files without folders

**Data Fetching:**

- Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Generated via Orval from backend OpenAPI spec
- Pattern: `use{Method}{Version}{OperationName}`
- Example: `useGetV2ListLibraryAgents`
- Regenerate with: `pnpm generate:api`
- **Never** use deprecated `BackendAPI` or `src/lib/autogpt-server-api/*`

**Code Conventions:**

- Use function declarations for components and handlers (not arrow functions)
- Only arrow functions for small inline lambdas (map, filter, etc.)
- Components: `PascalCase`, Hooks: `camelCase` with `use` prefix
- No barrel files or `index.ts` re-exports
- Minimal comments (code should be self-documenting)

**Styling:**

- Use Tailwind CSS utilities only
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Never use `src/components/__legacy__/*`
- Only use Phosphor Icons (`@phosphor-icons/react`)
- Prefer design tokens over hardcoded values

**Error Handling:**

- Render errors: Use `<ErrorCard />` component
- Mutation errors: Display with toast notifications
- Manual exceptions: Use `Sentry.captureException()`
- Global error boundaries already configured

**Testing:**

- Add/update Storybook stories for UI components (`pnpm storybook`)
- Run Playwright E2E tests with `pnpm test`
- Verify in Chromatic after PR

**Architecture:**

- Default to client components ("use client")
- Server components only for SEO or extreme TTFB needs
- Use React Query for server state (via generated hooks)
- Co-locate UI state in components/hooks

### Security Guidelines

**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):

- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses allow list approach for cacheable paths (static assets, health checks, public pages)
- Prevents sensitive data caching in browsers/proxies
- Add new cacheable endpoints to `CACHEABLE_PATHS`

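A minimal sketch of that allow-list idea (illustrative only; the real middleware in `security.py` is not part of this diff, so the class name and example paths here are assumptions, with only `CACHEABLE_PATHS` taken from the notes above):

```python
# Illustrative sketch; the actual implementation may differ.
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

CACHEABLE_PATHS = ("/health", "/static/")  # example entries, not the real allow list


class CacheProtectionMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        # Allow caching only for explicitly allow-listed path prefixes; everything
        # else gets the restrictive header so browsers/proxies never store it.
        if not any(request.url.path.startswith(p) for p in CACHEABLE_PATHS):
            response.headers["Cache-Control"] = (
                "no-store, no-cache, must-revalidate, private"
            )
        return response
```
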
### CI/CD Alignment

The repository has comprehensive CI workflows that test:

- **Backend**: Python 3.11-3.13, services (Redis/RabbitMQ/ClamAV), Prisma migrations, Poetry lock validation
- **Frontend**: Node.js 21, pnpm, Playwright with Docker Compose stack, API schema validation
- **Integration**: Full-stack type checking and E2E testing

@@ -229,6 +305,7 @@ Match these patterns when developing locally - the copilot setup environment mir

## Collaboration with Other AI Assistants

This repository is actively developed with assistance from Claude (via CLAUDE.md files). When working on this codebase:

- Check for existing CLAUDE.md files that provide additional context
- Follow established patterns and conventions already in the codebase
- Maintain consistency with existing code style and architecture

@@ -237,8 +314,9 @@ This repository is actively developed with assistance from Claude (via CLAUDE.md

## Trust These Instructions

These instructions are comprehensive and tested. Only perform additional searches if:

1. Information here is incomplete for your specific task
2. You encounter errors not covered by the workarounds
3. You need to understand implementation details not covered above

For detailed platform development patterns, refer to `autogpt_platform/CLAUDE.md` and `AGENTS.md` in the repository root.

@@ -63,6 +63,9 @@ poetry run pytest path/to/test.py --snapshot-update

```bash
# Install dependencies
cd frontend && pnpm i

# Generate API client from OpenAPI spec
pnpm generate:api

# Start development server
pnpm dev

@@ -75,12 +78,23 @@ pnpm storybook

# Build production
pnpm build

# Format and lint
pnpm format

# Type checking
pnpm types
```

We have a components library in autogpt_platform/frontend/src/components/atoms that should be used when adding new pages and components.

**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns.

**Key Frontend Conventions:**

- Separate render logic from data/behavior in components
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Use function declarations (not arrow functions) for components/handlers
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Only use Phosphor Icons
- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`

## Architecture Overview

@@ -95,11 +109,16 @@ We have a components library in autogpt_platform/frontend/src/components/atoms t

### Frontend Architecture

- **Framework**: Next.js App Router with React Server Components
- **State Management**: React hooks + Supabase client for real-time updates
- **Framework**: Next.js 15 App Router (client-first approach)
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
- **State Management**: React Query for server state, co-located UI state in components/hooks
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
- **Workflow Builder**: Visual graph editor using @xyflow/react
- **UI Components**: Radix UI primitives with Tailwind CSS styling
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
- **Icons**: Phosphor Icons only
- **Feature Flags**: LaunchDarkly integration
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
- **Testing**: Playwright for E2E, Storybook for component development

### Key Concepts

@@ -153,6 +172,7 @@ Key models (defined in `/backend/schema.prisma`):

**Adding a new block:**

Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:

- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)

@@ -160,6 +180,7 @@ Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-

- File organization

Quick steps:

1. Create new file in `/backend/backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class

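For orientation, a minimal sketch of such a block, using only names that appear elsewhere in this diff (`Block`, `BlockSchema`, `BlockOutput`, `BlockCategory`, `SchemaField`); the constructor arguments and category are assumptions, so treat the Block SDK Guide as authoritative:

```python
# Minimal illustrative sketch only; id, category, and constructor arguments are
# assumptions based on patterns visible in this diff, not SDK documentation.
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class EchoBlock(Block):
    class Input(BlockSchema):
        text: str = SchemaField(description="Text to echo back")

    class Output(BlockSchema):
        text: str = SchemaField(description="The same text, unchanged")

    def __init__(self):
        super().__init__(
            id="00000000-0000-0000-0000-000000000000",  # placeholder UUID
            description="Echoes its input (example only)",
            categories={BlockCategory.BASIC},
            input_schema=EchoBlock.Input,
            output_schema=EchoBlock.Output,
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Blocks yield (output_name, value) pairs, like the blocks elsewhere in this diff.
        yield "text", input_data.text
```
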
@@ -180,10 +201,20 @@ ex: do the inputs and outputs tie well together?

**Frontend feature development:**

1. Components go in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for new components
4. Test with Playwright if user-facing

See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:

1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
   - Add `usePageName.ts` hook for logic
   - Put sub-components in local `components/` folder
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
   - Use design system components from `src/components/` (atoms, molecules, organisms)
   - Never use `src/components/__legacy__/*`
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
   - Regenerate with `pnpm generate:api`
   - Pattern: `use{Method}{Version}{OperationName}`
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers

### Security Implementation

@@ -8,6 +8,11 @@ start-core:

stop-core:
	docker compose stop deps

reset-db:
	rm -rf db/docker/volumes/db/data
	cd backend && poetry run prisma migrate deploy
	cd backend && poetry run prisma generate

# View logs for core services
logs-core:
	docker compose logs -f deps

@@ -35,13 +40,18 @@ run-backend:

run-frontend:
	cd frontend && pnpm dev

test-data:
	cd backend && poetry run python test/test_data_creator.py

help:
	@echo "Usage: make <target>"
	@echo "Targets:"
	@echo "  start-core   - Start just the core services (Supabase, Redis, RabbitMQ) in background"
	@echo "  stop-core    - Stop the core services"
	@echo "  reset-db     - Reset the database by deleting the volume"
	@echo "  logs-core    - Tail the logs for core services"
	@echo "  format       - Format & lint backend (Python) and frontend (TypeScript) code"
	@echo "  migrate      - Run backend database migrations"
	@echo "  run-backend  - Run the backend FastAPI server"
	@echo "  run-frontend - Run the frontend Next.js development server"
	@echo "  test-data    - Run the test data creator"

@@ -94,42 +94,36 @@ def configure_logging(force_cloud_logging: bool = False) -> None:

    config = LoggingConfig()
    log_handlers: list[logging.Handler] = []

    structured_logging = config.enable_cloud_logging or force_cloud_logging

    # Console output handlers
    stdout = logging.StreamHandler(stream=sys.stdout)
    stdout.setLevel(config.level)
    stdout.addFilter(BelowLevelFilter(logging.WARNING))
    if config.level == logging.DEBUG:
        stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
    else:
        stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
    if not structured_logging:
        stdout = logging.StreamHandler(stream=sys.stdout)
        stdout.setLevel(config.level)
        stdout.addFilter(BelowLevelFilter(logging.WARNING))
        if config.level == logging.DEBUG:
            stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
        else:
            stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))

    stderr = logging.StreamHandler()
    stderr.setLevel(logging.WARNING)
    if config.level == logging.DEBUG:
        stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
    else:
        stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
        stderr = logging.StreamHandler()
        stderr.setLevel(logging.WARNING)
        if config.level == logging.DEBUG:
            stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
        else:
            stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))

    log_handlers += [stdout, stderr]
        log_handlers += [stdout, stderr]

    # Cloud logging setup
    if config.enable_cloud_logging or force_cloud_logging:
        import google.cloud.logging
        from google.cloud.logging.handlers import CloudLoggingHandler
        from google.cloud.logging_v2.handlers.transports import (
            BackgroundThreadTransport,
        )
    else:
        # Use Google Cloud Structured Log Handler. Log entries are printed to stdout
        # in a JSON format which is automatically picked up by Google Cloud Logging.
        from google.cloud.logging.handlers import StructuredLogHandler

        client = google.cloud.logging.Client()
        # Use BackgroundThreadTransport to prevent blocking the main thread
        # and deadlocks when gRPC calls to Google Cloud Logging hang
        cloud_handler = CloudLoggingHandler(
            client,
            name="autogpt_logs",
            transport=BackgroundThreadTransport,
        )
        cloud_handler.setLevel(config.level)
        log_handlers.append(cloud_handler)
        structured_log_handler = StructuredLogHandler(stream=sys.stdout)
        structured_log_handler.setLevel(config.level)
        log_handlers.append(structured_log_handler)

    # File logging setup
    if config.enable_file_logging:

@@ -185,7 +179,13 @@ def configure_logging(force_cloud_logging: bool = False) -> None:

    # Configure the root logger
    logging.basicConfig(
        format=DEBUG_LOG_FORMAT if config.level == logging.DEBUG else SIMPLE_LOG_FORMAT,
        format=(
            "%(levelname)s %(message)s"
            if structured_logging
            else (
                DEBUG_LOG_FORMAT if config.level == logging.DEBUG else SIMPLE_LOG_FORMAT
            )
        ),
        level=config.level,
        handlers=log_handlers,
    )

@@ -1,339 +0,0 @@

import asyncio
import inspect
import logging
import threading
import time
from functools import wraps
from typing import (
    Any,
    Callable,
    ParamSpec,
    Protocol,
    TypeVar,
    cast,
    runtime_checkable,
)

P = ParamSpec("P")
R = TypeVar("R")
R_co = TypeVar("R_co", covariant=True)

logger = logging.getLogger(__name__)


def _make_hashable_key(
    args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[Any, ...]:
    """
    Convert args and kwargs into a hashable cache key.

    Handles unhashable types like dict, list, set by converting them to
    their sorted string representations.
    """

    def make_hashable(obj: Any) -> Any:
        """Recursively convert an object to a hashable representation."""
        if isinstance(obj, dict):
            # Sort dict items to ensure consistent ordering
            return (
                "__dict__",
                tuple(sorted((k, make_hashable(v)) for k, v in obj.items())),
            )
        elif isinstance(obj, (list, tuple)):
            return ("__list__", tuple(make_hashable(item) for item in obj))
        elif isinstance(obj, set):
            return ("__set__", tuple(sorted(make_hashable(item) for item in obj)))
        elif hasattr(obj, "__dict__"):
            # Handle objects with __dict__ attribute
            return ("__obj__", obj.__class__.__name__, make_hashable(obj.__dict__))
        else:
            # For basic hashable types (str, int, bool, None, etc.)
            try:
                hash(obj)
                return obj
            except TypeError:
                # Fallback: convert to string representation
                return ("__str__", str(obj))

    hashable_args = tuple(make_hashable(arg) for arg in args)
    hashable_kwargs = tuple(sorted((k, make_hashable(v)) for k, v in kwargs.items()))
    return (hashable_args, hashable_kwargs)


@runtime_checkable
class CachedFunction(Protocol[P, R_co]):
    """Protocol for cached functions with cache management methods."""

    def cache_clear(self) -> None:
        """Clear all cached entries."""
        return None

    def cache_info(self) -> dict[str, int | None]:
        """Get cache statistics."""
        return {}

    def cache_delete(self, *args: P.args, **kwargs: P.kwargs) -> bool:
        """Delete a specific cache entry by its arguments. Returns True if entry existed."""
        return False

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
        """Call the cached function."""
        return None  # type: ignore


def cached(
    *,
    maxsize: int = 128,
    ttl_seconds: int | None = None,
) -> Callable[[Callable], CachedFunction]:
    """
    Thundering herd safe cache decorator for both sync and async functions.

    Uses double-checked locking to prevent multiple threads/coroutines from
    executing the expensive operation simultaneously during cache misses.

    Args:
        func: The function to cache (when used without parentheses)
        maxsize: Maximum number of cached entries
        ttl_seconds: Time to live in seconds. If None, entries never expire

    Returns:
        Decorated function or decorator

    Example:
        @cache()  # Default: maxsize=128, no TTL
        def expensive_sync_operation(param: str) -> dict:
            return {"result": param}

        @cache()  # Works with async too
        async def expensive_async_operation(param: str) -> dict:
            return {"result": param}

        @cache(maxsize=1000, ttl_seconds=300)  # Custom maxsize and TTL
        def another_operation(param: str) -> dict:
            return {"result": param}
    """

    def decorator(target_func):
        # Cache storage and per-event-loop locks
        cache_storage = {}
        _event_loop_locks = {}  # Maps event loop to its asyncio.Lock

        if inspect.iscoroutinefunction(target_func):

            def _get_cache_lock():
                """Get or create an asyncio.Lock for the current event loop."""
                try:
                    loop = asyncio.get_running_loop()
                except RuntimeError:
                    # No event loop, use None as default key
                    loop = None

                if loop not in _event_loop_locks:
                    return _event_loop_locks.setdefault(loop, asyncio.Lock())
                return _event_loop_locks[loop]

            @wraps(target_func)
            async def async_wrapper(*args: P.args, **kwargs: P.kwargs):
                key = _make_hashable_key(args, kwargs)
                current_time = time.time()

                # Fast path: check cache without lock
                if key in cache_storage:
                    if ttl_seconds is None:
                        logger.debug(f"Cache hit for {target_func.__name__}")
                        return cache_storage[key]
                    else:
                        cached_data = cache_storage[key]
                        if isinstance(cached_data, tuple):
                            result, timestamp = cached_data
                            if current_time - timestamp < ttl_seconds:
                                logger.debug(f"Cache hit for {target_func.__name__}")
                                return result

                # Slow path: acquire lock for cache miss/expiry
                async with _get_cache_lock():
                    # Double-check: another coroutine might have populated cache
                    if key in cache_storage:
                        if ttl_seconds is None:
                            return cache_storage[key]
                        else:
                            cached_data = cache_storage[key]
                            if isinstance(cached_data, tuple):
                                result, timestamp = cached_data
                                if current_time - timestamp < ttl_seconds:
                                    return result

                    # Cache miss - execute function
                    logger.debug(f"Cache miss for {target_func.__name__}")
                    result = await target_func(*args, **kwargs)

                    # Store result
                    if ttl_seconds is None:
                        cache_storage[key] = result
                    else:
                        cache_storage[key] = (result, current_time)

                    # Cleanup if needed
                    if len(cache_storage) > maxsize:
                        cutoff = maxsize // 2
                        oldest_keys = (
                            list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
                        )
                        for old_key in oldest_keys:
                            cache_storage.pop(old_key, None)

                    return result

            wrapper = async_wrapper

        else:
            # Sync function with threading.Lock
            cache_lock = threading.Lock()

            @wraps(target_func)
            def sync_wrapper(*args: P.args, **kwargs: P.kwargs):
                key = _make_hashable_key(args, kwargs)
                current_time = time.time()

                # Fast path: check cache without lock
                if key in cache_storage:
                    if ttl_seconds is None:
                        logger.debug(f"Cache hit for {target_func.__name__}")
                        return cache_storage[key]
                    else:
                        cached_data = cache_storage[key]
                        if isinstance(cached_data, tuple):
                            result, timestamp = cached_data
                            if current_time - timestamp < ttl_seconds:
                                logger.debug(f"Cache hit for {target_func.__name__}")
                                return result

                # Slow path: acquire lock for cache miss/expiry
                with cache_lock:
                    # Double-check: another thread might have populated cache
                    if key in cache_storage:
                        if ttl_seconds is None:
                            return cache_storage[key]
                        else:
                            cached_data = cache_storage[key]
                            if isinstance(cached_data, tuple):
                                result, timestamp = cached_data
                                if current_time - timestamp < ttl_seconds:
                                    return result

                    # Cache miss - execute function
                    logger.debug(f"Cache miss for {target_func.__name__}")
                    result = target_func(*args, **kwargs)

                    # Store result
                    if ttl_seconds is None:
                        cache_storage[key] = result
                    else:
                        cache_storage[key] = (result, current_time)

                    # Cleanup if needed
                    if len(cache_storage) > maxsize:
                        cutoff = maxsize // 2
                        oldest_keys = (
                            list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
                        )
                        for old_key in oldest_keys:
                            cache_storage.pop(old_key, None)

                    return result

            wrapper = sync_wrapper

        # Add cache management methods
        def cache_clear() -> None:
            cache_storage.clear()

        def cache_info() -> dict[str, int | None]:
            return {
                "size": len(cache_storage),
                "maxsize": maxsize,
                "ttl_seconds": ttl_seconds,
            }

        def cache_delete(*args, **kwargs) -> bool:
            """Delete a specific cache entry. Returns True if entry existed."""
            key = _make_hashable_key(args, kwargs)
            if key in cache_storage:
                del cache_storage[key]
                return True
            return False

        setattr(wrapper, "cache_clear", cache_clear)
        setattr(wrapper, "cache_info", cache_info)
        setattr(wrapper, "cache_delete", cache_delete)

        return cast(CachedFunction, wrapper)

    return decorator


def thread_cached(func):
    """
    Thread-local cache decorator for both sync and async functions.

    Each thread gets its own cache, which is useful for request-scoped caching
    in web applications where you want to cache within a single request but
    not across requests.

    Args:
        func: The function to cache

    Returns:
        Decorated function with thread-local caching

    Example:
        @thread_cached
        def expensive_operation(param: str) -> dict:
            return {"result": param}

        @thread_cached  # Works with async too
        async def expensive_async_operation(param: str) -> dict:
            return {"result": param}
    """
    thread_local = threading.local()

    def _clear():
        if hasattr(thread_local, "cache"):
            del thread_local.cache

    if inspect.iscoroutinefunction(func):

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            cache = getattr(thread_local, "cache", None)
            if cache is None:
                cache = thread_local.cache = {}
            key = _make_hashable_key(args, kwargs)
            if key not in cache:
                cache[key] = await func(*args, **kwargs)
            return cache[key]

        setattr(async_wrapper, "clear_cache", _clear)
        return async_wrapper

    else:

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            cache = getattr(thread_local, "cache", None)
            if cache is None:
                cache = thread_local.cache = {}
            key = _make_hashable_key(args, kwargs)
            if key not in cache:
                cache[key] = func(*args, **kwargs)
            return cache[key]

        setattr(sync_wrapper, "clear_cache", _clear)
        return sync_wrapper


def clear_thread_cache(func: Callable) -> None:
    """Clear thread-local cache for a function."""
    if clear := getattr(func, "clear_cache", None):
        clear()

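A short usage sketch of what this removed module provided; the decorator and helper names are the ones defined above, while the decorated functions below are made up for illustration:

```python
# Illustrative only; fetch_settings and per_request_lookup are made-up functions.
@cached(maxsize=64, ttl_seconds=60)
async def fetch_settings(user_id: str) -> dict:
    return {"user": user_id}

# The wrapper gains management helpers:
#   fetch_settings.cache_info()       -> {"size": ..., "maxsize": 64, "ttl_seconds": 60}
#   fetch_settings.cache_delete("42") -> evicts one entry; cache_clear() evicts all.
# Unhashable arguments (dicts, lists, sets) are fine because _make_hashable_key
# normalizes them into hashable tuples.


@thread_cached
def per_request_lookup(key: str) -> str:
    return key.upper()


clear_thread_cache(per_request_lookup)  # drops only the calling thread's cache
```
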
@@ -47,6 +47,7 @@ RUN poetry install --no-ansi --no-root

# Generate Prisma client
COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
RUN poetry run prisma generate

FROM debian:13-slim AS server_dependencies

@@ -92,6 +93,7 @@ FROM server_dependencies AS migrate

# Migration stage only needs schema and migrations - much lighter than full backend
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
COPY autogpt_platform/backend/backend/data/partial_types.py /app/autogpt_platform/backend/backend/data/partial_types.py
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations

FROM server_dependencies AS server

@@ -5,7 +5,7 @@ import re

from pathlib import Path
from typing import TYPE_CHECKING, TypeVar

from autogpt_libs.utils.cache import cached
from backend.util.cache import cached

logger = logging.getLogger(__name__)

@@ -16,7 +16,7 @@ if TYPE_CHECKING:

T = TypeVar("T")


@cached()
@cached(ttl_seconds=3600)
def load_all_blocks() -> dict[str, type["Block"]]:
    from backend.data.block import Block
    from backend.util.settings import Config

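The switch from `@cached()` to `@cached(ttl_seconds=3600)` means the block registry is rebuilt at most once an hour instead of living for the whole process. A hedged sketch of forcing an earlier refresh, assuming the relocated `backend.util.cache.cached` keeps the same `cache_clear` helper as the removed `autogpt_libs` module shown earlier:

```python
# Illustrative only; assumes backend.util.cache.cached attaches cache_clear()
# to the wrapped function, as the removed module above did.
blocks = load_all_blocks()        # first call scans the blocks package and caches the result
blocks_again = load_all_blocks()  # served from cache until the 3600 s TTL lapses

load_all_blocks.cache_clear()     # drop the cached registry immediately
fresh_blocks = load_all_blocks()  # next call rebuilds the registry
```
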
@@ -4,13 +4,13 @@ import mimetypes

from pathlib import Path
from typing import Any

import aiohttp
import discord
from pydantic import SecretStr

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, SchemaField
from backend.util.file import store_media_file
from backend.util.request import Requests
from backend.util.type import MediaFileType

from ._auth import (

@@ -114,10 +114,9 @@ class ReadDiscordMessagesBlock(Block):

                if message.attachments:
                    attachment = message.attachments[0]  # Process the first attachment
                    if attachment.filename.endswith((".txt", ".py")):
                        async with aiohttp.ClientSession() as session:
                            async with session.get(attachment.url) as response:
                                file_content = response.text()
                                self.output_data += f"\n\nFile from user: {attachment.filename}\nContent: {file_content}"
                        response = await Requests().get(attachment.url)
                        file_content = response.text()
                        self.output_data += f"\n\nFile from user: {attachment.filename}\nContent: {file_content}"

                await client.close()

@@ -699,16 +698,15 @@ class SendDiscordFileBlock(Block):

        elif file.startswith(("http://", "https://")):
            # URL - download the file
            async with aiohttp.ClientSession() as session:
                async with session.get(file) as response:
                    file_bytes = await response.read()
            response = await Requests().get(file)
            file_bytes = response.content

            # Try to get filename from URL if not provided
            if not filename:
                from urllib.parse import urlparse
            # Try to get filename from URL if not provided
            if not filename:
                from urllib.parse import urlparse

                path = urlparse(file).path
                detected_filename = Path(path).name or "download"
                path = urlparse(file).path
                detected_filename = Path(path).name or "download"
        else:
            # Local file path - read from stored media file
            # This would be a path from a previous block's output

@@ -2,7 +2,7 @@ from typing import Any

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.json import json
from backend.util.json import loads


class StepThroughItemsBlock(Block):

@@ -68,7 +68,7 @@ class StepThroughItemsBlock(Block):

            raise ValueError(
                f"Input too large: {len(data)} bytes > {MAX_ITEM_SIZE} bytes"
            )
            items = json.loads(data)
            items = loads(data)
        else:
            items = data


@@ -62,10 +62,10 @@ TEST_CREDENTIALS_OAUTH = OAuth2Credentials(

    title="Mock Linear API key",
    username="mock-linear-username",
    access_token=SecretStr("mock-linear-access-token"),
    access_token_expires_at=None,
    access_token_expires_at=1672531200,  # Mock expiration time for short-lived token
    refresh_token=SecretStr("mock-linear-refresh-token"),
    refresh_token_expires_at=None,
    scopes=["mock-linear-scopes"],
    scopes=["read", "write"],
)

TEST_CREDENTIALS_API_KEY = APIKeyCredentials(

@@ -2,7 +2,9 @@

Linear OAuth handler implementation.
"""

import base64
import json
import time
from typing import Optional
from urllib.parse import urlencode

@@ -38,8 +40,9 @@ class LinearOAuthHandler(BaseOAuthHandler):

        self.client_secret = client_secret
        self.redirect_uri = redirect_uri
        self.auth_base_url = "https://linear.app/oauth/authorize"
        self.token_url = "https://api.linear.app/oauth/token"  # Correct token URL
        self.token_url = "https://api.linear.app/oauth/token"
        self.revoke_url = "https://api.linear.app/oauth/revoke"
        self.migrate_url = "https://api.linear.app/oauth/migrate_old_token"

    def get_login_url(
        self, scopes: list[str], state: str, code_challenge: Optional[str]

@@ -82,19 +85,84 @@ class LinearOAuthHandler(BaseOAuthHandler):

        return True  # Linear doesn't return JSON on successful revoke

    async def migrate_old_token(
        self, credentials: OAuth2Credentials
    ) -> OAuth2Credentials:
        """
        Migrate an old long-lived token to a new short-lived token with refresh token.

        This uses Linear's /oauth/migrate_old_token endpoint to exchange current
        long-lived tokens for short-lived tokens with refresh tokens without
        requiring users to re-authorize.
        """
        if not credentials.access_token:
            raise ValueError("No access token to migrate")

        request_body = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
        }

        headers = {
            "Authorization": f"Bearer {credentials.access_token.get_secret_value()}",
            "Content-Type": "application/x-www-form-urlencoded",
        }

        response = await Requests().post(
            self.migrate_url, data=request_body, headers=headers
        )

        if not response.ok:
            try:
                error_data = response.json()
                error_message = error_data.get("error", "Unknown error")
                error_description = error_data.get("error_description", "")
                if error_description:
                    error_message = f"{error_message}: {error_description}"
            except json.JSONDecodeError:
                error_message = response.text
            raise LinearAPIException(
                f"Failed to migrate Linear token ({response.status}): {error_message}",
                response.status,
            )

        token_data = response.json()

        # Extract token expiration
        now = int(time.time())
        expires_in = token_data.get("expires_in")
        access_token_expires_at = None
        if expires_in:
            access_token_expires_at = now + expires_in

        new_credentials = OAuth2Credentials(
            provider=self.PROVIDER_NAME,
            title=credentials.title,
            username=credentials.username,
            access_token=token_data["access_token"],
            scopes=credentials.scopes,  # Preserve original scopes
            refresh_token=token_data.get("refresh_token"),
            access_token_expires_at=access_token_expires_at,
            refresh_token_expires_at=None,
        )

        new_credentials.id = credentials.id
        return new_credentials

    async def _refresh_tokens(
        self, credentials: OAuth2Credentials
    ) -> OAuth2Credentials:
        if not credentials.refresh_token:
            raise ValueError(
                "No refresh token available."
            )  # Linear uses non-expiring tokens
                "No refresh token available. Token may need to be migrated to the new refresh token system."
            )

        return await self._request_tokens(
            {
                "refresh_token": credentials.refresh_token.get_secret_value(),
                "grant_type": "refresh_token",
            }
            },
            current_credentials=credentials,
        )

    async def _request_tokens(

@@ -102,16 +170,33 @@ class LinearOAuthHandler(BaseOAuthHandler):

        params: dict[str, str],
        current_credentials: Optional[OAuth2Credentials] = None,
    ) -> OAuth2Credentials:
        # Determine if this is a refresh token request
        is_refresh = params.get("grant_type") == "refresh_token"

        # Build request body with appropriate grant_type
        request_body = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "grant_type": "authorization_code",  # Ensure grant_type is correct
            **params,
        }

        headers = {
            "Content-Type": "application/x-www-form-urlencoded"
        }  # Correct header for token request
        # Set default grant_type if not provided
        if "grant_type" not in request_body:
            request_body["grant_type"] = "authorization_code"

        headers = {"Content-Type": "application/x-www-form-urlencoded"}

        # For refresh token requests, support HTTP Basic Authentication as recommended
        if is_refresh:
            # Option 1: Use HTTP Basic Auth (preferred by Linear)
            client_credentials = f"{self.client_id}:{self.client_secret}"
            encoded_credentials = base64.b64encode(client_credentials.encode()).decode()
            headers["Authorization"] = f"Basic {encoded_credentials}"

            # Remove client credentials from body when using Basic Auth
            request_body.pop("client_id", None)
            request_body.pop("client_secret", None)

        response = await Requests().post(
            self.token_url, data=request_body, headers=headers
        )

@@ -120,6 +205,9 @@ class LinearOAuthHandler(BaseOAuthHandler):

            try:
                error_data = response.json()
                error_message = error_data.get("error", "Unknown error")
                error_description = error_data.get("error_description", "")
                if error_description:
                    error_message = f"{error_message}: {error_description}"
            except json.JSONDecodeError:
                error_message = response.text
            raise LinearAPIException(

@@ -129,27 +217,84 @@ class LinearOAuthHandler(BaseOAuthHandler):

        token_data = response.json()

        # Note: Linear access tokens do not expire, so we set expires_at to None
        # Extract token expiration if provided (for new refresh token implementation)
        now = int(time.time())
        expires_in = token_data.get("expires_in")
        access_token_expires_at = None
        if expires_in:
            access_token_expires_at = now + expires_in

        # Get username - preserve from current credentials if refreshing
        username = None
        if current_credentials and is_refresh:
            username = current_credentials.username
        elif "user" in token_data:
            username = token_data["user"].get("name", "Unknown User")
        else:
            # Fetch username using the access token
            username = await self._request_username(token_data["access_token"])

        new_credentials = OAuth2Credentials(
            provider=self.PROVIDER_NAME,
            title=current_credentials.title if current_credentials else None,
            username=token_data.get("user", {}).get(
                "name", "Unknown User"
            ),  # extract name or set appropriate
            username=username or "Unknown User",
            access_token=token_data["access_token"],
            scopes=token_data["scope"].split(
                ","
            ),  # Linear returns comma-separated scopes
            refresh_token=token_data.get(
                "refresh_token"
            ),  # Linear uses non-expiring tokens so this might be null
            access_token_expires_at=None,
            refresh_token_expires_at=None,
            scopes=(
                token_data["scope"].split(",")
                if "scope" in token_data
                else (current_credentials.scopes if current_credentials else [])
            ),
            refresh_token=token_data.get("refresh_token"),
            access_token_expires_at=access_token_expires_at,
            refresh_token_expires_at=None,  # Linear doesn't provide refresh token expiration
        )

        if current_credentials:
            new_credentials.id = current_credentials.id

        return new_credentials

    async def get_access_token(self, credentials: OAuth2Credentials) -> str:
        """
        Returns a valid access token, handling migration and refresh as needed.

        This overrides the base implementation to handle Linear's token migration
        from old long-lived tokens to new short-lived tokens with refresh tokens.
        """
        # If token has no expiration and no refresh token, it might be an old token
        # that needs migration
        if (
            credentials.access_token_expires_at is None
            and credentials.refresh_token is None
        ):
            try:
                # Attempt to migrate the old token
                migrated_credentials = await self.migrate_old_token(credentials)
                # Update the credentials store would need to be handled by the caller
                # For now, use the migrated credentials for this request
                credentials = migrated_credentials
            except LinearAPIException:
                # Migration failed, try to use the old token as-is
                # This maintains backward compatibility
                pass

        # Use the standard refresh logic from the base class
        if self.needs_refresh(credentials):
            credentials = await self.refresh_tokens(credentials)

        return credentials.access_token.get_secret_value()

    def needs_migration(self, credentials: OAuth2Credentials) -> bool:
        """
        Check if credentials represent an old long-lived token that needs migration.

        Old tokens have no expiration time and no refresh token.
        """
        return (
            credentials.access_token_expires_at is None
            and credentials.refresh_token is None
        )

    async def _request_username(self, access_token: str) -> Optional[str]:
        # Use the LinearClient to fetch user details using GraphQL
        from ._api import LinearClient

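Taken together, callers never need to invoke `migrate_old_token` directly; a hedged sketch of the flow (the constructor arguments and the credentials object below are made up for illustration):

```python
# Illustrative flow only; constructor arguments and `stored_credentials` are made up.
async def get_linear_token(stored_credentials) -> str:
    handler = LinearOAuthHandler(
        client_id="...",
        client_secret="...",
        redirect_uri="https://example.com/oauth/callback",
    )
    # Old credentials (no expiry, no refresh token) are migrated via
    # /oauth/migrate_old_token; afterwards the usual needs_refresh()/refresh_tokens()
    # path keeps the short-lived access token fresh.
    return await handler.get_access_token(stored_credentials)
```
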
@@ -104,8 +104,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
||||
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
|
||||
CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
|
||||
CLAUDE_3_5_HAIKU = "claude-3-5-haiku-latest"
|
||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||
# AI/ML API models
|
||||
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
||||
@@ -224,12 +222,6 @@ MODEL_METADATA = {
|
||||
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000
|
||||
), # claude-3-7-sonnet-20250219
|
||||
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 8192
|
||||
), # claude-3-5-sonnet-20241022
|
||||
LlmModel.CLAUDE_3_5_HAIKU: ModelMetadata(
|
||||
"anthropic", 200000, 8192
|
||||
), # claude-3-5-haiku-20241022
|
||||
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
|
||||
"anthropic", 200000, 4096
|
||||
), # claude-3-haiku-20240307
|
||||
@@ -1562,7 +1554,9 @@ class AIConversationBlock(AIBlockBase):
|
||||
("prompt", list),
|
||||
],
|
||||
test_mock={
|
||||
"llm_call": lambda *args, **kwargs: "The 2020 World Series was played at Globe Life Field in Arlington, Texas."
|
||||
"llm_call": lambda *args, **kwargs: dict(
|
||||
response="The 2020 World Series was played at Globe Life Field in Arlington, Texas."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1591,7 +1585,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
),
|
||||
credentials=credentials,
|
||||
)
|
||||
yield "response", response
|
||||
yield "response", response["response"]
|
||||
yield "prompt", self.prompt
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
@@ -10,6 +8,7 @@ import pydantic
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.request import Requests
|
||||
|
||||
|
||||
class RSSEntry(pydantic.BaseModel):
|
||||
@@ -103,35 +102,29 @@ class ReadRSSFeedBlock(Block):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_feed(url: str) -> dict[str, Any]:
|
||||
async def parse_feed(url: str) -> dict[str, Any]:
|
||||
# Security fix: Add protection against memory exhaustion attacks
|
||||
MAX_FEED_SIZE = 10 * 1024 * 1024 # 10MB limit for RSS feeds
|
||||
|
||||
# Validate URL
|
||||
parsed_url = urllib.parse.urlparse(url)
|
||||
if parsed_url.scheme not in ("http", "https"):
|
||||
raise ValueError(f"Invalid URL scheme: {parsed_url.scheme}")
|
||||
|
||||
# Download with size limit
|
||||
# Download feed content with size limit
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=30) as response:
|
||||
# Check content length if available
|
||||
content_length = response.headers.get("Content-Length")
|
||||
if content_length and int(content_length) > MAX_FEED_SIZE:
|
||||
raise ValueError(
|
||||
f"Feed too large: {content_length} bytes exceeds {MAX_FEED_SIZE} limit"
|
||||
)
|
||||
response = await Requests(raise_for_status=True).get(url)
|
||||
|
||||
# Read with size limit
|
||||
content = response.read(MAX_FEED_SIZE + 1)
|
||||
if len(content) > MAX_FEED_SIZE:
|
||||
raise ValueError(
|
||||
f"Feed too large: exceeds {MAX_FEED_SIZE} byte limit"
|
||||
)
|
||||
# Check content length if available
|
||||
content_length = response.headers.get("Content-Length")
|
||||
if content_length and int(content_length) > MAX_FEED_SIZE:
|
||||
raise ValueError(
|
||||
f"Feed too large: {content_length} bytes exceeds {MAX_FEED_SIZE} limit"
|
||||
)
|
||||
|
||||
# Parse with feedparser using the validated content
|
||||
# feedparser has built-in protection against XML attacks
|
||||
return feedparser.parse(content) # type: ignore
|
||||
# Get content with size limit
|
||||
content = response.content
|
||||
if len(content) > MAX_FEED_SIZE:
|
||||
raise ValueError(f"Feed too large: exceeds {MAX_FEED_SIZE} byte limit")
|
||||
|
||||
# Parse with feedparser using the validated content
|
||||
# feedparser has built-in protection against XML attacks
|
||||
return feedparser.parse(content) # type: ignore
|
||||
except Exception as e:
|
||||
# Log error and return empty feed
|
||||
logging.warning(f"Failed to parse RSS feed from {url}: {e}")
|
||||
@@ -145,7 +138,7 @@ class ReadRSSFeedBlock(Block):
|
||||
while keep_going:
|
||||
keep_going = input_data.run_continuously
|
||||
|
||||
feed = self.parse_feed(input_data.rss_url)
|
||||
feed = await self.parse_feed(input_data.rss_url)
|
||||
all_entries = []
|
||||
|
||||
for entry in feed["entries"]:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import signal
|
||||
import threading
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
|
||||
@@ -26,6 +27,13 @@ from backend.sdk import (
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
# Suppress false positive cleanup warning of litellm (a dependency of stagehand)
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message="coroutine 'close_litellm_async_clients' was never awaited",
|
||||
category=RuntimeWarning,
|
||||
)
|
||||
|
||||
# Store the original method
|
||||
original_register_signal_handlers = stagehand.main.Stagehand._register_signal_handlers
|
||||
|
||||
|
||||
@@ -362,7 +362,7 @@ class TestLLMStatsTracking:
|
||||
assert block.execution_stats.llm_call_count == 1
|
||||
|
||||
# Check output
|
||||
assert outputs["response"] == {"response": "AI response to conversation"}
|
||||
assert outputs["response"] == "AI response to conversation"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ai_list_generator_with_retries(self):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from youtube_transcript_api._api import YouTubeTranscriptApi
|
||||
from youtube_transcript_api._errors import NoTranscriptFound
|
||||
from youtube_transcript_api._transcripts import FetchedTranscript
|
||||
from youtube_transcript_api.formatters import TextFormatter
|
||||
|
||||
@@ -64,7 +65,29 @@ class TranscribeYoutubeVideoBlock(Block):
|
||||
|
||||
@staticmethod
|
||||
def get_transcript(video_id: str) -> FetchedTranscript:
|
||||
return YouTubeTranscriptApi().fetch(video_id=video_id)
|
||||
"""
|
||||
Get transcript for a video, preferring English but falling back to any available language.
|
||||
|
||||
:param video_id: The YouTube video ID
|
||||
:return: The fetched transcript
|
||||
:raises: Any exception except NoTranscriptFound for requested languages
|
||||
"""
|
||||
api = YouTubeTranscriptApi()
|
||||
try:
|
||||
# Try to get English transcript first (default behavior)
|
||||
return api.fetch(video_id=video_id)
|
||||
except NoTranscriptFound:
|
||||
# If English is not available, get the first available transcript
|
||||
transcript_list = api.list(video_id)
|
||||
# Try manually created transcripts first, then generated ones
|
||||
available_transcripts = list(
|
||||
transcript_list._manually_created_transcripts.values()
|
||||
) + list(transcript_list._generated_transcripts.values())
|
||||
if available_transcripts:
|
||||
# Fetch the first available transcript
|
||||
return available_transcripts[0].fetch()
|
||||
# If no transcripts at all, re-raise the original error
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def format_transcript(transcript: FetchedTranscript) -> str:
|
||||
|
||||
@@ -45,9 +45,6 @@ class MainApp(AppProcess):
|
||||
|
||||
app.main(silent=True)
|
||||
|
||||
def cleanup(self):
|
||||
pass
|
||||
|
||||
|
||||
@click.group()
|
||||
def main():
|
||||
|
||||
@@ -20,7 +20,6 @@ from typing import (
|
||||
|
||||
import jsonref
|
||||
import jsonschema
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from prisma.models import AgentBlock
|
||||
from prisma.types import AgentBlockCreateInput
|
||||
from pydantic import BaseModel
|
||||
@@ -28,6 +27,7 @@ from pydantic import BaseModel
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
from backend.util.cache import cached
|
||||
from backend.util.settings import Config
|
||||
|
||||
from .model import (
|
||||
@@ -722,7 +722,7 @@ def get_block(block_id: str) -> Block[BlockSchema, BlockSchema] | None:
|
||||
return cls() if cls else None
|
||||
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=3600)
|
||||
def get_webhook_block_ids() -> Sequence[str]:
|
||||
return [
|
||||
id
|
||||
@@ -731,7 +731,7 @@ def get_webhook_block_ids() -> Sequence[str]:
|
||||
]
|
||||
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=3600)
|
||||
def get_io_block_ids() -> Sequence[str]:
|
||||
return [
|
||||
id
|
||||
|
||||
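The decorator change above swaps the unbounded `@cached()` for a time-bounded cache. A hedged sketch of the same pattern on a standalone function (assuming, as the diff shows, that `cached` from backend.util.cache accepts a `ttl_seconds` keyword; the function itself is hypothetical):

    from backend.util.cache import cached

    @cached(ttl_seconds=3600)  # cached result expires after an hour instead of living forever
    def expensive_lookup(key: str) -> str:
        return key.upper()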
@@ -1,7 +1,11 @@
|
||||
from typing import Type
|
||||
|
||||
from backend.blocks.ai_music_generator import AIMusicGeneratorBlock
|
||||
from backend.blocks.ai_shortform_video_block import AIShortformVideoCreatorBlock
|
||||
from backend.blocks.ai_shortform_video_block import (
|
||||
AIAdMakerVideoCreatorBlock,
|
||||
AIScreenshotToVideoAdBlock,
|
||||
AIShortformVideoCreatorBlock,
|
||||
)
|
||||
from backend.blocks.apollo.organization import SearchOrganizationsBlock
|
||||
from backend.blocks.apollo.people import SearchPeopleBlock
|
||||
from backend.blocks.apollo.person import GetPersonDetailBlock
|
||||
@@ -72,8 +76,6 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
||||
LlmModel.CLAUDE_4_5_SONNET: 9,
|
||||
LlmModel.CLAUDE_3_7_SONNET: 5,
|
||||
LlmModel.CLAUDE_3_5_SONNET: 4,
|
||||
LlmModel.CLAUDE_3_5_HAIKU: 1, # $0.80 / $4.00
|
||||
LlmModel.CLAUDE_3_HAIKU: 1,
|
||||
LlmModel.AIML_API_QWEN2_5_72B: 1,
|
||||
LlmModel.AIML_API_LLAMA3_1_70B: 1,
|
||||
@@ -323,7 +325,31 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
],
|
||||
AIShortformVideoCreatorBlock: [
|
||||
BlockCost(
|
||||
cost_amount=50,
|
||||
cost_amount=307,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": revid_credentials.id,
|
||||
"provider": revid_credentials.provider,
|
||||
"type": revid_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
AIAdMakerVideoCreatorBlock: [
|
||||
BlockCost(
|
||||
cost_amount=714,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": revid_credentials.id,
|
||||
"provider": revid_credentials.provider,
|
||||
"type": revid_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
AIScreenshotToVideoAdBlock: [
|
||||
BlockCost(
|
||||
cost_amount=612,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": revid_credentials.id,
|
||||
|
||||
@@ -5,7 +5,6 @@ from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import stripe
|
||||
from prisma import Json
|
||||
from prisma.enums import (
|
||||
CreditRefundRequestStatus,
|
||||
CreditTransactionType,
|
||||
@@ -13,16 +12,12 @@ from prisma.enums import (
|
||||
OnboardingStep,
|
||||
)
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditRefundRequest, CreditTransaction, User
|
||||
from prisma.types import (
|
||||
CreditRefundRequestCreateInput,
|
||||
CreditTransactionCreateInput,
|
||||
CreditTransactionWhereInput,
|
||||
)
|
||||
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
|
||||
from prisma.types import CreditRefundRequestCreateInput, CreditTransactionWhereInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.data.includes import MAX_CREDIT_REFUND_REQUESTS_FETCH
|
||||
from backend.data.model import (
|
||||
AutoTopUpConfig,
|
||||
@@ -36,7 +31,8 @@ from backend.data.user import get_user_by_id, get_user_email_by_id
|
||||
from backend.notifications.notifications import queue_notification_async
|
||||
from backend.server.v2.admin.model import UserHistoryResponse
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
from backend.util.json import SafeJson, dumps
|
||||
from backend.util.models import Pagination
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.settings import Settings
|
||||
@@ -49,6 +45,10 @@ stripe.api_key = settings.secrets.stripe_api_key
|
||||
logger = logging.getLogger(__name__)
|
||||
base_url = settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
|
||||
# PostgreSQL 4-byte signed integer bounds, used to clamp balances (also imported by tests)
|
||||
POSTGRES_INT_MAX = 2147483647
|
||||
POSTGRES_INT_MIN = -2147483648
|
||||
|
||||
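These values are the bounds of PostgreSQL's 4-byte signed INT type (-2^31 and 2^31 - 1); the raw SQL added later in this PR clamps balances into this range, which in plain Python is equivalent to:

    new_balance = max(POSTGRES_INT_MIN, min(POSTGRES_INT_MAX, balance + amount))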
|
||||
class UsageTransactionMetadata(BaseModel):
|
||||
graph_exec_id: str | None = None
|
||||
@@ -139,14 +139,20 @@ class UserCreditBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
|
||||
async def onboarding_reward(
|
||||
self, user_id: str, credits: int, step: OnboardingStep
|
||||
) -> bool:
|
||||
"""
|
||||
Reward the user with credits for completing an onboarding step.
|
||||
Won't reward if the user has already received credits for the step.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID.
|
||||
credits (int): The amount to reward.
|
||||
step (OnboardingStep): The onboarding step.
|
||||
|
||||
Returns:
|
||||
bool: True if rewarded, False if already rewarded.
|
||||
"""
|
||||
pass
|
||||
|
||||
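Because the method now returns a bool, callers can branch on whether the reward was actually granted. A hedged caller-side sketch (amount and step are illustrative, mirroring the concurrency test added later in this PR):

    rewarded = await user_credit.onboarding_reward(user_id, 500, OnboardingStep.WELCOME)
    if not rewarded:
        # The step was already rewarded earlier; no credits were added this time.
        logger.debug("Onboarding reward for WELCOME was already granted")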
@@ -236,6 +242,12 @@ class UserCreditBase(ABC):
|
||||
"""
|
||||
Returns the current balance of the user & the latest balance snapshot time.
|
||||
"""
|
||||
# Check UserBalance first for efficiency and consistency
|
||||
user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
|
||||
if user_balance:
|
||||
return user_balance.balance, user_balance.updatedAt
|
||||
|
||||
# Fallback to transaction history computation if UserBalance doesn't exist
|
||||
top_time = self.time_now()
|
||||
snapshot = await CreditTransaction.prisma().find_first(
|
||||
where={
|
||||
@@ -250,72 +262,86 @@ class UserCreditBase(ABC):
|
||||
snapshot_balance = snapshot.runningBalance or 0 if snapshot else 0
|
||||
snapshot_time = snapshot.createdAt if snapshot else datetime_min
|
||||
|
||||
# Get transactions created after the snapshot; these should not normally exist, but handle them just in case.
|
||||
transactions = await CreditTransaction.prisma().group_by(
|
||||
by=["userId"],
|
||||
sum={"amount": True},
|
||||
max={"createdAt": True},
|
||||
where={
|
||||
"userId": user_id,
|
||||
"createdAt": {
|
||||
"gt": snapshot_time,
|
||||
"lte": top_time,
|
||||
},
|
||||
"isActive": True,
|
||||
},
|
||||
)
|
||||
transaction_balance = (
|
||||
int(transactions[0].get("_sum", {}).get("amount", 0) + snapshot_balance)
|
||||
if transactions
|
||||
else snapshot_balance
|
||||
)
|
||||
transaction_time = (
|
||||
datetime.fromisoformat(
|
||||
str(transactions[0].get("_max", {}).get("createdAt", datetime_min))
|
||||
)
|
||||
if transactions
|
||||
else snapshot_time
|
||||
)
|
||||
return transaction_balance, transaction_time
|
||||
return snapshot_balance, snapshot_time
|
||||
|
||||
@func_retry
|
||||
async def _enable_transaction(
|
||||
self,
|
||||
transaction_key: str,
|
||||
user_id: str,
|
||||
metadata: Json,
|
||||
metadata: SafeJson,
|
||||
new_transaction_key: str | None = None,
|
||||
):
|
||||
transaction = await CreditTransaction.prisma().find_first_or_raise(
|
||||
where={"transactionKey": transaction_key, "userId": user_id}
|
||||
# First check if transaction exists and is inactive (safety check)
|
||||
transaction = await CreditTransaction.prisma().find_first(
|
||||
where={
|
||||
"transactionKey": transaction_key,
|
||||
"userId": user_id,
|
||||
"isActive": False,
|
||||
}
|
||||
)
|
||||
if transaction.isActive:
|
||||
return
|
||||
if not transaction:
|
||||
# Transaction doesn't exist or is already active, return early
|
||||
return None
|
||||
|
||||
async with db.locked_transaction(f"usr_trx_{user_id}"):
|
||||
|
||||
transaction = await CreditTransaction.prisma().find_first_or_raise(
|
||||
where={"transactionKey": transaction_key, "userId": user_id}
|
||||
# Atomic operation to enable transaction and update user balance using UserBalance
|
||||
result = await query_raw_with_schema(
|
||||
"""
|
||||
WITH user_balance_lock AS (
|
||||
SELECT
|
||||
$2::text as userId,
|
||||
COALESCE(
|
||||
(SELECT balance FROM {schema_prefix}"UserBalance" WHERE "userId" = $2 FOR UPDATE),
|
||||
-- Fallback: compute balance from transaction history if UserBalance doesn't exist
|
||||
(SELECT COALESCE(ct."runningBalance", 0)
|
||||
FROM {schema_prefix}"CreditTransaction" ct
|
||||
WHERE ct."userId" = $2
|
||||
AND ct."isActive" = true
|
||||
AND ct."runningBalance" IS NOT NULL
|
||||
ORDER BY ct."createdAt" DESC
|
||||
LIMIT 1),
|
||||
0
|
||||
) as balance
|
||||
),
|
||||
transaction_check AS (
|
||||
SELECT * FROM {schema_prefix}"CreditTransaction"
|
||||
WHERE "transactionKey" = $1 AND "userId" = $2 AND "isActive" = false
|
||||
),
|
||||
balance_update AS (
|
||||
INSERT INTO {schema_prefix}"UserBalance" ("userId", "balance", "updatedAt")
|
||||
SELECT
|
||||
$2::text,
|
||||
user_balance_lock.balance + transaction_check.amount,
|
||||
CURRENT_TIMESTAMP
|
||||
FROM user_balance_lock, transaction_check
|
||||
ON CONFLICT ("userId") DO UPDATE SET
|
||||
"balance" = EXCLUDED."balance",
|
||||
"updatedAt" = EXCLUDED."updatedAt"
|
||||
RETURNING "balance", "updatedAt"
|
||||
),
|
||||
transaction_update AS (
|
||||
UPDATE {schema_prefix}"CreditTransaction"
|
||||
SET "transactionKey" = COALESCE($4, $1),
|
||||
"isActive" = true,
|
||||
"runningBalance" = balance_update.balance,
|
||||
"createdAt" = balance_update."updatedAt",
|
||||
"metadata" = $3::jsonb
|
||||
FROM balance_update, transaction_check
|
||||
WHERE {schema_prefix}"CreditTransaction"."transactionKey" = transaction_check."transactionKey"
|
||||
AND {schema_prefix}"CreditTransaction"."userId" = transaction_check."userId"
|
||||
RETURNING {schema_prefix}"CreditTransaction"."runningBalance"
|
||||
)
|
||||
if transaction.isActive:
|
||||
return
|
||||
SELECT "runningBalance" as balance FROM transaction_update;
|
||||
""",
|
||||
transaction_key, # $1
|
||||
user_id, # $2
|
||||
dumps(metadata.data), # $3 - use pre-serialized JSON string for JSONB
|
||||
new_transaction_key, # $4
|
||||
)
|
||||
|
||||
user_balance, _ = await self._get_credits(user_id)
|
||||
await CreditTransaction.prisma().update(
|
||||
where={
|
||||
"creditTransactionIdentifier": {
|
||||
"transactionKey": transaction_key,
|
||||
"userId": user_id,
|
||||
}
|
||||
},
|
||||
data={
|
||||
"transactionKey": new_transaction_key or transaction_key,
|
||||
"isActive": True,
|
||||
"runningBalance": user_balance + transaction.amount,
|
||||
"createdAt": self.time_now(),
|
||||
"metadata": metadata,
|
||||
},
|
||||
)
|
||||
if result:
|
||||
# UserBalance is already updated by the CTE
|
||||
return result[0]["balance"]
|
||||
|
||||
async def _add_transaction(
|
||||
self,
|
||||
@@ -326,12 +352,54 @@ class UserCreditBase(ABC):
|
||||
transaction_key: str | None = None,
|
||||
ceiling_balance: int | None = None,
|
||||
fail_insufficient_credits: bool = True,
|
||||
metadata: Json = SafeJson({}),
|
||||
metadata: SafeJson = SafeJson({}),
|
||||
) -> tuple[int, str]:
|
||||
"""
|
||||
Add a new transaction for the user.
|
||||
This is the only method that should be used to add a new transaction.
|
||||
|
||||
ATOMIC OPERATION DESIGN DECISION:
|
||||
================================
|
||||
This method uses PostgreSQL row-level locking (FOR UPDATE) for atomic credit operations.
|
||||
After extensive analysis of concurrency patterns and correctness requirements, we determined
|
||||
that the FOR UPDATE approach is necessary despite the latency overhead.
|
||||
|
||||
WHY FOR UPDATE LOCKING IS REQUIRED:
|
||||
----------------------------------
|
||||
1. **Data Consistency**: Credit operations must be ACID-compliant. The balance check,
|
||||
calculation, and update must be atomic to prevent race conditions where:
|
||||
- Multiple spend operations could exceed available balance
|
||||
- Lost update problems could occur with concurrent top-ups
|
||||
- Refunds could create negative balances incorrectly
|
||||
|
||||
2. **Serializability**: FOR UPDATE ensures operations are serialized at the database level,
|
||||
guaranteeing that each transaction sees a consistent view of the balance before applying changes.
|
||||
|
||||
3. **Correctness Over Performance**: Financial operations require absolute correctness.
|
||||
The ~10-50ms latency increase from row locking is acceptable for the guarantee that
|
||||
no user will ever have an incorrect balance due to race conditions.
|
||||
|
||||
4. **PostgreSQL Optimization**: Modern PostgreSQL versions optimize row locks efficiently.
|
||||
The performance cost is minimal compared to the complexity and risk of lock-free approaches.
|
||||
|
||||
ALTERNATIVES CONSIDERED AND REJECTED:
|
||||
------------------------------------
|
||||
- **Optimistic Concurrency**: Using version numbers or timestamps would require complex
|
||||
retry logic and could still fail under high contention scenarios.
|
||||
- **Application-Level Locking**: Redis locks or similar would add network overhead and
|
||||
single points of failure while being less reliable than database locks.
|
||||
- **Event Sourcing**: Would require complete architectural changes and eventual consistency
|
||||
models that don't fit our real-time balance requirements.
|
||||
|
||||
PERFORMANCE CHARACTERISTICS:
|
||||
---------------------------
|
||||
- Single user operations: 10-50ms latency (acceptable for financial operations)
|
||||
- Concurrent operations on same user: Serialized (prevents data corruption)
|
||||
- Concurrent operations on different users: Fully parallel (no blocking)
|
||||
|
||||
This design prioritizes correctness and data integrity over raw performance,
|
||||
which is the appropriate choice for a credit/payment system.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID.
|
||||
amount (int): The amount of credits to add.
|
||||
@@ -345,40 +413,142 @@ class UserCreditBase(ABC):
|
||||
Returns:
|
||||
tuple[int, str]: The new balance & the transaction key.
|
||||
"""
|
||||
async with db.locked_transaction(f"usr_trx_{user_id}"):
|
||||
# Get latest balance snapshot
|
||||
user_balance, _ = await self._get_credits(user_id)
|
||||
|
||||
if ceiling_balance and amount > 0 and user_balance >= ceiling_balance:
|
||||
# Quick validation for ceiling balance to avoid unnecessary database operations
|
||||
if ceiling_balance and amount > 0:
|
||||
current_balance, _ = await self._get_credits(user_id)
|
||||
if current_balance >= ceiling_balance:
|
||||
raise ValueError(
|
||||
f"You already have enough balance of ${user_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
|
||||
f"You already have enough balance of ${current_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
|
||||
)
|
||||
|
||||
if amount < 0 and user_balance + amount < 0:
|
||||
if fail_insufficient_credits:
|
||||
raise InsufficientBalanceError(
|
||||
message=f"Insufficient balance of ${user_balance/100}, where this will cost ${abs(amount)/100}",
|
||||
user_id=user_id,
|
||||
balance=user_balance,
|
||||
amount=amount,
|
||||
# Single unified atomic operation for all transaction types using UserBalance
|
||||
try:
|
||||
result = await query_raw_with_schema(
|
||||
"""
|
||||
WITH user_balance_lock AS (
|
||||
SELECT
|
||||
$1::text as userId,
|
||||
-- CRITICAL: FOR UPDATE lock prevents concurrent modifications to the same user's balance
|
||||
-- This ensures atomic read-modify-write operations and prevents race conditions
|
||||
COALESCE(
|
||||
(SELECT balance FROM {schema_prefix}"UserBalance" WHERE "userId" = $1 FOR UPDATE),
|
||||
-- Fallback: compute balance from transaction history if UserBalance doesn't exist
|
||||
(SELECT COALESCE(ct."runningBalance", 0)
|
||||
FROM {schema_prefix}"CreditTransaction" ct
|
||||
WHERE ct."userId" = $1
|
||||
AND ct."isActive" = true
|
||||
AND ct."runningBalance" IS NOT NULL
|
||||
ORDER BY ct."createdAt" DESC
|
||||
LIMIT 1),
|
||||
0
|
||||
) as balance
|
||||
),
|
||||
balance_update AS (
|
||||
INSERT INTO {schema_prefix}"UserBalance" ("userId", "balance", "updatedAt")
|
||||
SELECT
|
||||
$1::text,
|
||||
CASE
|
||||
-- For inactive transactions: Don't update balance
|
||||
WHEN $5::boolean = false THEN user_balance_lock.balance
|
||||
-- For ceiling balance (amount > 0): Apply ceiling
|
||||
WHEN $2 > 0 AND $7::int IS NOT NULL AND user_balance_lock.balance > $7::int - $2 THEN $7::int
|
||||
-- For regular operations: Apply with overflow/underflow protection
|
||||
WHEN user_balance_lock.balance + $2 > $6::int THEN $6::int
|
||||
WHEN user_balance_lock.balance + $2 < $10::int THEN $10::int
|
||||
ELSE user_balance_lock.balance + $2
|
||||
END,
|
||||
CURRENT_TIMESTAMP
|
||||
FROM user_balance_lock
|
||||
WHERE (
|
||||
$5::boolean = false OR -- Allow inactive transactions
|
||||
$2 >= 0 OR -- Allow positive amounts (top-ups, grants)
|
||||
$8::boolean = false OR -- Allow when insufficient balance check is disabled
|
||||
user_balance_lock.balance + $2 >= 0 -- Allow spending only when sufficient balance
|
||||
)
|
||||
ON CONFLICT ("userId") DO UPDATE SET
|
||||
"balance" = EXCLUDED."balance",
|
||||
"updatedAt" = EXCLUDED."updatedAt"
|
||||
RETURNING "balance", "updatedAt"
|
||||
),
|
||||
transaction_insert AS (
|
||||
INSERT INTO {schema_prefix}"CreditTransaction" (
|
||||
"userId", "amount", "type", "runningBalance",
|
||||
"metadata", "isActive", "createdAt", "transactionKey"
|
||||
)
|
||||
SELECT
|
||||
$1::text,
|
||||
$2::int,
|
||||
$3::text::{schema_prefix}"CreditTransactionType",
|
||||
CASE
|
||||
-- For inactive transactions: Set runningBalance to original balance (don't apply the change yet)
|
||||
WHEN $5::boolean = false THEN user_balance_lock.balance
|
||||
ELSE COALESCE(balance_update.balance, user_balance_lock.balance)
|
||||
END,
|
||||
$4::jsonb,
|
||||
$5::boolean,
|
||||
COALESCE(balance_update."updatedAt", CURRENT_TIMESTAMP),
|
||||
COALESCE($9, gen_random_uuid()::text)
|
||||
FROM user_balance_lock
|
||||
LEFT JOIN balance_update ON true
|
||||
WHERE (
|
||||
$5::boolean = false OR -- Allow inactive transactions
|
||||
$2 >= 0 OR -- Allow positive amounts (top-ups, grants)
|
||||
$8::boolean = false OR -- Allow when insufficient balance check is disabled
|
||||
user_balance_lock.balance + $2 >= 0 -- Allow spending only when sufficient balance
|
||||
)
|
||||
RETURNING "runningBalance", "transactionKey"
|
||||
)
|
||||
SELECT "runningBalance" as balance, "transactionKey" FROM transaction_insert;
|
||||
""",
|
||||
user_id, # $1
|
||||
amount, # $2
|
||||
transaction_type.value, # $3
|
||||
dumps(metadata.data), # $4 - use pre-serialized JSON string for JSONB
|
||||
is_active, # $5
|
||||
POSTGRES_INT_MAX, # $6 - overflow protection
|
||||
ceiling_balance, # $7 - ceiling balance (nullable)
|
||||
fail_insufficient_credits, # $8 - check balance for spending
|
||||
transaction_key, # $9 - transaction key (nullable)
|
||||
POSTGRES_INT_MIN, # $10 - underflow protection
|
||||
)
|
||||
except Exception as e:
|
||||
# Convert raw SQL unique constraint violations to UniqueViolationError
|
||||
# for consistent exception handling throughout the codebase
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
"already exists" in error_str
|
||||
or "duplicate key" in error_str
|
||||
or "unique constraint" in error_str
|
||||
):
|
||||
# Extract table and constraint info for better error messages
|
||||
# Re-raise as a UniqueViolationError but with proper format
|
||||
# Create a minimal data structure that the error constructor expects
|
||||
raise UniqueViolationError({"error": str(e), "user_facing_error": {}})
|
||||
# For any other error, re-raise as-is
|
||||
raise
|
||||
|
||||
amount = min(-user_balance, 0)
|
||||
if result:
|
||||
new_balance, tx_key = result[0]["balance"], result[0]["transactionKey"]
|
||||
# UserBalance is already updated by the CTE
|
||||
return new_balance, tx_key
|
||||
|
||||
# Create the transaction
|
||||
transaction_data: CreditTransactionCreateInput = {
|
||||
"userId": user_id,
|
||||
"amount": amount,
|
||||
"runningBalance": user_balance + amount,
|
||||
"type": transaction_type,
|
||||
"metadata": metadata,
|
||||
"isActive": is_active,
|
||||
"createdAt": self.time_now(),
|
||||
}
|
||||
if transaction_key:
|
||||
transaction_data["transactionKey"] = transaction_key
|
||||
tx = await CreditTransaction.prisma().create(data=transaction_data)
|
||||
return user_balance + amount, tx.transactionKey
|
||||
# If no result, either user doesn't exist or insufficient balance
|
||||
user = await User.prisma().find_unique(where={"id": user_id})
|
||||
if not user:
|
||||
raise ValueError(f"User {user_id} not found")
|
||||
|
||||
# Must be insufficient balance for spending operation
|
||||
if amount < 0 and fail_insufficient_credits:
|
||||
current_balance, _ = await self._get_credits(user_id)
|
||||
raise InsufficientBalanceError(
|
||||
message=f"Insufficient balance of ${current_balance/100}, where this will cost ${abs(amount)/100}",
|
||||
user_id=user_id,
|
||||
balance=current_balance,
|
||||
amount=amount,
|
||||
)
|
||||
|
||||
# Unexpected case
|
||||
raise ValueError(f"Transaction failed for user {user_id}, amount {amount}")
|
||||
|
||||
|
||||
class UserCredit(UserCreditBase):
|
||||
@@ -453,9 +623,10 @@ class UserCredit(UserCreditBase):
|
||||
{"reason": f"Reward for completing {step.value} onboarding step."}
|
||||
),
|
||||
)
|
||||
return True
|
||||
except UniqueViolationError:
|
||||
# Already rewarded for this step
|
||||
pass
|
||||
# User already received this reward
|
||||
return False
|
||||
|
||||
async def top_up_refund(
|
||||
self, user_id: str, transaction_key: str, metadata: dict[str, str]
|
||||
@@ -644,7 +815,7 @@ class UserCredit(UserCreditBase):
|
||||
):
|
||||
# init metadata, without sharing it with the world
|
||||
metadata = metadata or {}
|
||||
if not metadata["reason"]:
|
||||
if not metadata.get("reason"):
|
||||
match top_up_type:
|
||||
case TopUpType.MANUAL:
|
||||
metadata["reason"] = {"reason": f"Top up credits for {user_id}"}
|
||||
@@ -974,8 +1145,8 @@ class DisabledUserCredit(UserCreditBase):
|
||||
async def top_up_credits(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def onboarding_reward(self, *args, **kwargs):
|
||||
pass
|
||||
async def onboarding_reward(self, *args, **kwargs) -> bool:
|
||||
return True
|
||||
|
||||
async def top_up_intent(self, *args, **kwargs) -> str:
|
||||
return ""
|
||||
@@ -993,14 +1164,31 @@ class DisabledUserCredit(UserCreditBase):
|
||||
pass
|
||||
|
||||
|
||||
def get_user_credit_model() -> UserCreditBase:
|
||||
async def get_user_credit_model(user_id: str) -> UserCreditBase:
|
||||
"""
|
||||
Get the credit model for a user, considering LaunchDarkly flags.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID to check flags for.
|
||||
|
||||
Returns:
|
||||
UserCreditBase: The appropriate credit model for the user
|
||||
"""
|
||||
if not settings.config.enable_credit:
|
||||
return DisabledUserCredit()
|
||||
|
||||
if settings.config.enable_beta_monthly_credit:
|
||||
return BetaUserCredit(settings.config.num_user_credits_refill)
|
||||
# Check LaunchDarkly flag for payment pilot users
|
||||
# Default to False (beta monthly credit behavior) to maintain current behavior
|
||||
is_payment_enabled = await is_feature_enabled(
|
||||
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
|
||||
)
|
||||
|
||||
return UserCredit()
|
||||
if is_payment_enabled:
|
||||
# Payment enabled users get UserCredit (no monthly refills, enable payments)
|
||||
return UserCredit()
|
||||
else:
|
||||
# Default behavior: users get beta monthly credits
|
||||
return BetaUserCredit(settings.config.num_user_credits_refill)
|
||||
|
||||
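Since get_user_credit_model is now async and flag-gated per user, call sites must pass the user ID and await the result; a minimal caller sketch (mirroring the admin_get_user_history change further down in this PR):

    user_credit_model = await get_user_credit_model(user_id)
    balance, last_update = await user_credit_model._get_credits(user_id)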
|
||||
def get_block_costs() -> dict[str, list["BlockCost"]]:
|
||||
@@ -1090,7 +1278,8 @@ async def admin_get_user_history(
|
||||
)
|
||||
reason = metadata.get("reason", "No reason provided")
|
||||
|
||||
balance, last_update = await get_user_credit_model()._get_credits(tx.userId)
|
||||
user_credit_model = await get_user_credit_model(tx.userId)
|
||||
balance, _ = await user_credit_model._get_credits(tx.userId)
|
||||
|
||||
history.append(
|
||||
UserTransaction(
|
||||
|
||||
autogpt_platform/backend/backend/data/credit_ceiling_test.py (new file, 172 lines)
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
Test ceiling balance functionality to ensure auto top-up limits work correctly.
|
||||
|
||||
This test was added to cover a previously untested code path that could lead to
|
||||
incorrect balance capping behavior.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user for ceiling tests."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
pass
|
||||
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
)
|
||||
|
||||
|
||||
async def cleanup_test_user(user_id: str) -> None:
|
||||
"""Clean up test user and their transactions."""
|
||||
try:
|
||||
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
|
||||
await User.prisma().delete_many(where={"id": user_id})
|
||||
except Exception as e:
|
||||
# Log cleanup failures but don't fail the test
|
||||
print(f"Warning: Failed to cleanup test user {user_id}: {e}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_ceiling_balance_rejects_when_above_threshold(server: SpinTestServer):
|
||||
"""Test that ceiling balance correctly rejects top-ups when balance is above threshold."""
|
||||
credit_system = UserCredit()
|
||||
user_id = f"ceiling-test-{uuid4()}"
|
||||
await create_test_user(user_id)
|
||||
|
||||
try:
|
||||
# Give user balance of 1000 ($10) using internal method (bypasses Stripe)
|
||||
await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=1000,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
metadata=SafeJson({"test": "initial_balance"}),
|
||||
)
|
||||
current_balance = await credit_system.get_credits(user_id)
|
||||
assert current_balance == 1000
|
||||
|
||||
# Try to add 200 more with ceiling of 800 (should reject since 1000 > 800)
|
||||
with pytest.raises(ValueError, match="You already have enough balance"):
|
||||
await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=200,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
ceiling_balance=800, # Ceiling lower than current balance
|
||||
)
|
||||
|
||||
# Balance should remain unchanged
|
||||
final_balance = await credit_system.get_credits(user_id)
|
||||
assert final_balance == 1000, f"Balance should remain 1000, got {final_balance}"
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_ceiling_balance_clamps_when_would_exceed(server: SpinTestServer):
|
||||
"""Test that ceiling balance correctly clamps amounts that would exceed the ceiling."""
|
||||
credit_system = UserCredit()
|
||||
user_id = f"ceiling-clamp-test-{uuid4()}"
|
||||
await create_test_user(user_id)
|
||||
|
||||
try:
|
||||
# Give user balance of 500 ($5) using internal method (bypasses Stripe)
|
||||
await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=500,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
metadata=SafeJson({"test": "initial_balance"}),
|
||||
)
|
||||
|
||||
# Add 800 more with ceiling of 1000 (should clamp to 1000, not reach 1300)
|
||||
final_balance, _ = await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=800,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
ceiling_balance=1000, # Ceiling should clamp 500 + 800 = 1300 to 1000
|
||||
)
|
||||
|
||||
# Balance should be clamped to ceiling
|
||||
assert (
|
||||
final_balance == 1000
|
||||
), f"Balance should be clamped to 1000, got {final_balance}"
|
||||
|
||||
# Verify with get_credits too
|
||||
stored_balance = await credit_system.get_credits(user_id)
|
||||
assert (
|
||||
stored_balance == 1000
|
||||
), f"Stored balance should be 1000, got {stored_balance}"
|
||||
|
||||
# Verify transaction shows the clamped amount
|
||||
transactions = await CreditTransaction.prisma().find_many(
|
||||
where={"userId": user_id, "type": CreditTransactionType.TOP_UP},
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
|
||||
# Should have 2 transactions: 500 + (500 to reach ceiling of 1000)
|
||||
assert len(transactions) == 2
|
||||
|
||||
# The most recent transaction's running balance should be clamped at the ceiling (1000)
|
||||
second_tx = transactions[0] # Most recent
|
||||
assert second_tx.runningBalance == 1000
|
||||
# The actual amount recorded could be 800 (what was requested) but balance was clamped
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_ceiling_balance_allows_when_under_threshold(server: SpinTestServer):
|
||||
"""Test that ceiling balance allows top-ups when balance is under threshold."""
|
||||
credit_system = UserCredit()
|
||||
user_id = f"ceiling-under-test-{uuid4()}"
|
||||
await create_test_user(user_id)
|
||||
|
||||
try:
|
||||
# Give user balance of 300 ($3) using internal method (bypasses Stripe)
|
||||
await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=300,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
metadata=SafeJson({"test": "initial_balance"}),
|
||||
)
|
||||
|
||||
# Add 200 more with ceiling of 1000 (should succeed: 300 + 200 = 500 < 1000)
|
||||
final_balance, _ = await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=200,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
ceiling_balance=1000,
|
||||
)
|
||||
|
||||
# Balance should be exactly 500
|
||||
assert final_balance == 500, f"Balance should be 500, got {final_balance}"
|
||||
|
||||
# Verify with get_credits too
|
||||
stored_balance = await credit_system.get_credits(user_id)
|
||||
assert (
|
||||
stored_balance == 500
|
||||
), f"Stored balance should be 500, got {stored_balance}"
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
autogpt_platform/backend/backend/data/credit_concurrency_test.py (new file, 737 lines)
@@ -0,0 +1,737 @@
|
||||
"""
|
||||
Concurrency and atomicity tests for the credit system.
|
||||
|
||||
These tests ensure the credit system handles high-concurrency scenarios correctly
|
||||
without race conditions, deadlocks, or inconsistent state.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
from uuid import uuid4
|
||||
|
||||
import prisma.enums
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
|
||||
from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
# Test with both UserCredit and BetaUserCredit if needed
|
||||
credit_system = UserCredit()
|
||||
|
||||
|
||||
async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user with initial balance."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
pass
|
||||
|
||||
# Ensure UserBalance record exists
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
)
|
||||
|
||||
|
||||
async def cleanup_test_user(user_id: str) -> None:
|
||||
"""Clean up test user and their transactions."""
|
||||
try:
|
||||
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
|
||||
await UserBalance.prisma().delete_many(where={"userId": user_id})
|
||||
await User.prisma().delete_many(where={"id": user_id})
|
||||
except Exception as e:
|
||||
# Log cleanup failures but don't fail the test
|
||||
print(f"Warning: Failed to cleanup test user {user_id}: {e}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_concurrent_spends_same_user(server: SpinTestServer):
|
||||
"""Test multiple concurrent spends from the same user don't cause race conditions."""
|
||||
user_id = f"concurrent-test-{uuid4()}"
|
||||
await create_test_user(user_id)
|
||||
|
||||
try:
|
||||
# Give user initial balance using internal method (bypasses Stripe)
|
||||
await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=1000,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
metadata=SafeJson({"test": "initial_balance"}),
|
||||
)
|
||||
|
||||
# Try to spend 10 x $1 concurrently
|
||||
async def spend_one_dollar(idx: int):
|
||||
try:
|
||||
return await credit_system.spend_credits(
|
||||
user_id,
|
||||
100, # $1
|
||||
UsageTransactionMetadata(
|
||||
graph_exec_id=f"concurrent-{idx}",
|
||||
reason=f"Concurrent spend {idx}",
|
||||
),
|
||||
)
|
||||
except InsufficientBalanceError:
|
||||
return None
|
||||
|
||||
# Run 10 concurrent spends
|
||||
results = await asyncio.gather(
|
||||
*[spend_one_dollar(i) for i in range(10)], return_exceptions=True
|
||||
)
|
||||
|
||||
# Count successful spends
|
||||
successful = [
|
||||
r for r in results if r is not None and not isinstance(r, Exception)
|
||||
]
|
||||
failed = [r for r in results if isinstance(r, InsufficientBalanceError)]
|
||||
|
||||
# All 10 should succeed since we have exactly $10
|
||||
assert len(successful) == 10, f"Expected 10 successful, got {len(successful)}"
|
||||
assert len(failed) == 0, f"Expected 0 failures, got {len(failed)}"
|
||||
|
||||
# Final balance should be exactly 0
|
||||
final_balance = await credit_system.get_credits(user_id)
|
||||
assert final_balance == 0, f"Expected balance 0, got {final_balance}"
|
||||
|
||||
# Verify transaction history is consistent
|
||||
transactions = await CreditTransaction.prisma().find_many(
|
||||
where={"userId": user_id, "type": prisma.enums.CreditTransactionType.USAGE}
|
||||
)
|
||||
assert (
|
||||
len(transactions) == 10
|
||||
), f"Expected 10 transactions, got {len(transactions)}"
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_concurrent_spends_insufficient_balance(server: SpinTestServer):
|
||||
"""Test that concurrent spends correctly enforce balance limits."""
|
||||
user_id = f"insufficient-test-{uuid4()}"
|
||||
await create_test_user(user_id)
|
||||
|
||||
try:
|
||||
# Give user limited balance using internal method (bypasses Stripe)
|
||||
await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=500,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
metadata=SafeJson({"test": "limited_balance"}),
|
||||
)
|
||||
|
||||
# Try to spend 10 x $1 concurrently (but only have $5)
|
||||
async def spend_one_dollar(idx: int):
|
||||
try:
|
||||
return await credit_system.spend_credits(
|
||||
user_id,
|
||||
100, # $1
|
||||
UsageTransactionMetadata(
|
||||
graph_exec_id=f"insufficient-{idx}",
|
||||
reason=f"Insufficient spend {idx}",
|
||||
),
|
||||
)
|
||||
except InsufficientBalanceError:
|
||||
return "FAILED"
|
||||
|
||||
# Run 10 concurrent spends
|
||||
results = await asyncio.gather(
|
||||
*[spend_one_dollar(i) for i in range(10)], return_exceptions=True
|
||||
)
|
||||
|
||||
# Count successful vs failed
|
||||
successful = [
|
||||
r
|
||||
for r in results
|
||||
if r not in ["FAILED", None] and not isinstance(r, Exception)
|
||||
]
|
||||
failed = [r for r in results if r == "FAILED"]
|
||||
|
||||
# Exactly 5 should succeed, 5 should fail
|
||||
assert len(successful) == 5, f"Expected 5 successful, got {len(successful)}"
|
||||
assert len(failed) == 5, f"Expected 5 failures, got {len(failed)}"
|
||||
|
||||
# Final balance should be exactly 0
|
||||
final_balance = await credit_system.get_credits(user_id)
|
||||
assert final_balance == 0, f"Expected balance 0, got {final_balance}"
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_concurrent_mixed_operations(server: SpinTestServer):
|
||||
"""Test concurrent mix of spends, top-ups, and balance checks."""
|
||||
user_id = f"mixed-test-{uuid4()}"
|
||||
await create_test_user(user_id)
|
||||
|
||||
try:
|
||||
# Initial balance using internal method (bypasses Stripe)
|
||||
await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=1000,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
metadata=SafeJson({"test": "initial_balance"}),
|
||||
)
|
||||
|
||||
# Mix of operations
|
||||
async def mixed_operations():
|
||||
operations = []
|
||||
|
||||
# 5 spends of $1 each
|
||||
for i in range(5):
|
||||
operations.append(
|
||||
credit_system.spend_credits(
|
||||
user_id,
|
||||
100,
|
||||
UsageTransactionMetadata(reason=f"Mixed spend {i}"),
|
||||
)
|
||||
)
|
||||
|
||||
# 3 top-ups of $2 each using internal method
|
||||
for i in range(3):
|
||||
operations.append(
|
||||
credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=200,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
metadata=SafeJson({"test": f"concurrent_topup_{i}"}),
|
||||
)
|
||||
)
|
||||
|
||||
# 10 balance checks
|
||||
for i in range(10):
|
||||
operations.append(credit_system.get_credits(user_id))
|
||||
|
||||
return await asyncio.gather(*operations, return_exceptions=True)
|
||||
|
||||
results = await mixed_operations()
|
||||
|
||||
# Check no exceptions occurred
|
||||
exceptions = [
|
||||
r
|
||||
for r in results
|
||||
if isinstance(r, Exception) and not isinstance(r, InsufficientBalanceError)
|
||||
]
|
||||
assert len(exceptions) == 0, f"Unexpected exceptions: {exceptions}"
|
||||
|
||||
# Final balance should be: 1000 - 500 + 600 = 1100
|
||||
final_balance = await credit_system.get_credits(user_id)
|
||||
assert final_balance == 1100, f"Expected balance 1100, got {final_balance}"
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_race_condition_exact_balance(server: SpinTestServer):
|
||||
"""Test spending exact balance amount concurrently doesn't go negative."""
|
||||
user_id = f"exact-balance-{uuid4()}"
|
||||
await create_test_user(user_id)
|
||||
|
||||
try:
|
||||
# Give exact amount using internal method (bypasses Stripe)
|
||||
await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=100,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
metadata=SafeJson({"test": "exact_amount"}),
|
||||
)
|
||||
|
||||
# Try to spend $1 twice concurrently
|
||||
async def spend_exact():
|
||||
try:
|
||||
return await credit_system.spend_credits(
|
||||
user_id, 100, UsageTransactionMetadata(reason="Exact spend")
|
||||
)
|
||||
except InsufficientBalanceError:
|
||||
return "FAILED"
|
||||
|
||||
# Both try to spend the full balance
|
||||
result1, result2 = await asyncio.gather(spend_exact(), spend_exact())
|
||||
|
||||
# Exactly one should succeed
|
||||
results = [result1, result2]
|
||||
successful = [
|
||||
r for r in results if r != "FAILED" and not isinstance(r, Exception)
|
||||
]
|
||||
failed = [r for r in results if r == "FAILED"]
|
||||
|
||||
assert len(successful) == 1, f"Expected 1 success, got {len(successful)}"
|
||||
assert len(failed) == 1, f"Expected 1 failure, got {len(failed)}"
|
||||
|
||||
# Balance should be exactly 0, never negative
|
||||
final_balance = await credit_system.get_credits(user_id)
|
||||
assert final_balance == 0, f"Expected balance 0, got {final_balance}"
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_onboarding_reward_idempotency(server: SpinTestServer):
|
||||
"""Test that onboarding rewards are idempotent (can't be claimed twice)."""
|
||||
user_id = f"onboarding-test-{uuid4()}"
|
||||
await create_test_user(user_id)
|
||||
|
||||
try:
|
||||
# Use WELCOME step which is defined in the OnboardingStep enum
|
||||
# Try to claim same reward multiple times concurrently
|
||||
async def claim_reward():
|
||||
try:
|
||||
result = await credit_system.onboarding_reward(
|
||||
user_id, 500, prisma.enums.OnboardingStep.WELCOME
|
||||
)
|
||||
return "SUCCESS" if result else "DUPLICATE"
|
||||
except Exception as e:
|
||||
print(f"Claim reward failed: {e}")
|
||||
return "FAILED"
|
||||
|
||||
# Try 5 concurrent claims of the same reward
|
||||
results = await asyncio.gather(*[claim_reward() for _ in range(5)])
|
||||
|
||||
# Count results
|
||||
success_count = results.count("SUCCESS")
|
||||
failed_count = results.count("FAILED")
|
||||
|
||||
# At least one should succeed, others should be duplicates
|
||||
assert success_count >= 1, f"At least one claim should succeed, got {results}"
|
||||
assert failed_count == 0, f"No claims should fail, got {results}"
|
||||
|
||||
# Check balance - should only have 500, not 2500
|
||||
final_balance = await credit_system.get_credits(user_id)
|
||||
assert final_balance == 500, f"Expected balance 500, got {final_balance}"
|
||||
|
||||
# Check only one transaction exists
|
||||
transactions = await CreditTransaction.prisma().find_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"type": prisma.enums.CreditTransactionType.GRANT,
|
||||
"transactionKey": f"REWARD-{user_id}-WELCOME",
|
||||
}
|
||||
)
|
||||
assert (
|
||||
len(transactions) == 1
|
||||
), f"Expected 1 reward transaction, got {len(transactions)}"
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_integer_overflow_protection(server: SpinTestServer):
|
||||
"""Test that integer overflow is prevented by clamping to POSTGRES_INT_MAX."""
|
||||
user_id = f"overflow-test-{uuid4()}"
|
||||
await create_test_user(user_id)
|
||||
|
||||
try:
|
||||
# Try to add amount that would overflow
|
||||
max_int = POSTGRES_INT_MAX
|
||||
|
||||
# First, set balance near max
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": max_int - 100},
|
||||
"update": {"balance": max_int - 100},
|
||||
},
|
||||
)
|
||||
|
||||
# Try to add more than possible - should clamp to POSTGRES_INT_MAX
|
||||
await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=200,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
metadata=SafeJson({"test": "overflow_protection"}),
|
||||
)
|
||||
|
||||
# Balance should be clamped to max_int, not overflowed
|
||||
final_balance = await credit_system.get_credits(user_id)
|
||||
assert (
|
||||
final_balance == max_int
|
||||
), f"Balance should be clamped to {max_int}, got {final_balance}"
|
||||
|
||||
# Verify transaction was created with clamped amount
|
||||
transactions = await CreditTransaction.prisma().find_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"type": prisma.enums.CreditTransactionType.TOP_UP,
|
||||
},
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
assert len(transactions) > 0, "Transaction should be created"
|
||||
assert (
|
||||
transactions[0].runningBalance == max_int
|
||||
), "Transaction should show clamped balance"
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_high_concurrency_stress(server: SpinTestServer):
|
||||
"""Stress test with many concurrent operations."""
|
||||
user_id = f"stress-test-{uuid4()}"
|
||||
await create_test_user(user_id)
|
||||
|
||||
try:
|
||||
# Initial balance using internal method (bypasses Stripe)
|
||||
initial_balance = 10000 # $100
|
||||
await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=initial_balance,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
metadata=SafeJson({"test": "stress_test_balance"}),
|
||||
)
|
||||
|
||||
# Run many concurrent operations
|
||||
async def random_operation(idx: int):
|
||||
operation = random.choice(["spend", "check"])
|
||||
|
||||
if operation == "spend":
|
||||
amount = random.randint(1, 50) # $0.01 to $0.50
|
||||
try:
|
||||
return (
|
||||
"spend",
|
||||
amount,
|
||||
await credit_system.spend_credits(
|
||||
user_id,
|
||||
amount,
|
||||
UsageTransactionMetadata(reason=f"Stress {idx}"),
|
||||
),
|
||||
)
|
||||
except InsufficientBalanceError:
|
||||
return ("spend_failed", amount, None)
|
||||
else:
|
||||
balance = await credit_system.get_credits(user_id)
|
||||
return ("check", 0, balance)
|
||||
|
||||
# Run 100 concurrent operations
|
||||
results = await asyncio.gather(
|
||||
*[random_operation(i) for i in range(100)], return_exceptions=True
|
||||
)
|
||||
|
||||
# Calculate expected final balance
|
||||
total_spent = sum(
|
||||
r[1]
|
||||
for r in results
|
||||
if not isinstance(r, Exception) and isinstance(r, tuple) and r[0] == "spend"
|
||||
)
|
||||
expected_balance = initial_balance - total_spent
|
||||
|
||||
# Verify final balance
|
||||
final_balance = await credit_system.get_credits(user_id)
|
||||
assert (
|
||||
final_balance == expected_balance
|
||||
), f"Expected {expected_balance}, got {final_balance}"
|
||||
assert final_balance >= 0, "Balance went negative!"
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestServer):
|
||||
"""Test multiple concurrent spends when there's sufficient balance for all."""
|
||||
user_id = f"multi-spend-test-{uuid4()}"
|
||||
await create_test_user(user_id)
|
||||
|
||||
try:
|
||||
# Give user 150 balance ($1.50) using internal method (bypasses Stripe)
|
||||
await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=150,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
metadata=SafeJson({"test": "sufficient_balance"}),
|
||||
)
|
||||
|
||||
# Track individual timing to see serialization
|
||||
timings = {}
|
||||
|
||||
async def spend_with_detailed_timing(amount: int, label: str):
|
||||
start = asyncio.get_event_loop().time()
|
||||
try:
|
||||
await credit_system.spend_credits(
|
||||
user_id,
|
||||
amount,
|
||||
UsageTransactionMetadata(
|
||||
graph_exec_id=f"concurrent-{label}",
|
||||
reason=f"Concurrent spend {label}",
|
||||
),
|
||||
)
|
||||
end = asyncio.get_event_loop().time()
|
||||
timings[label] = {"start": start, "end": end, "duration": end - start}
|
||||
return f"{label}-SUCCESS"
|
||||
except Exception as e:
|
||||
end = asyncio.get_event_loop().time()
|
||||
timings[label] = {
|
||||
"start": start,
|
||||
"end": end,
|
||||
"duration": end - start,
|
||||
"error": str(e),
|
||||
}
|
||||
return f"{label}-FAILED: {e}"
|
||||
|
||||
# Run concurrent spends: 10, 20, 30 (total 60, well under 150)
|
||||
overall_start = asyncio.get_event_loop().time()
|
||||
results = await asyncio.gather(
|
||||
spend_with_detailed_timing(10, "spend-10"),
|
||||
spend_with_detailed_timing(20, "spend-20"),
|
||||
spend_with_detailed_timing(30, "spend-30"),
|
||||
return_exceptions=True,
|
||||
)
|
||||
overall_end = asyncio.get_event_loop().time()
|
||||
|
||||
print(f"Results: {results}")
|
||||
print(f"Overall duration: {overall_end - overall_start:.4f}s")
|
||||
|
||||
# Analyze timing to detect serialization vs true concurrency
|
||||
print("\nTiming analysis:")
|
||||
for label, timing in timings.items():
|
||||
print(
|
||||
f" {label}: started at {timing['start']:.4f}, ended at {timing['end']:.4f}, duration {timing['duration']:.4f}s"
|
||||
)
|
||||
|
||||
# Check if operations overlapped (true concurrency) or were serialized
|
||||
sorted_timings = sorted(timings.items(), key=lambda x: x[1]["start"])
|
||||
print("\nExecution order by start time:")
|
||||
for i, (label, timing) in enumerate(sorted_timings):
|
||||
print(f" {i+1}. {label}: {timing['start']:.4f} -> {timing['end']:.4f}")
|
||||
|
||||
# Check for overlap (true concurrency) vs serialization
|
||||
overlaps = []
|
||||
for i in range(len(sorted_timings) - 1):
|
||||
current = sorted_timings[i]
|
||||
next_op = sorted_timings[i + 1]
|
||||
if current[1]["end"] > next_op[1]["start"]:
|
||||
overlaps.append(f"{current[0]} overlaps with {next_op[0]}")
|
||||
|
||||
if overlaps:
|
||||
print(f"✅ TRUE CONCURRENCY detected: {overlaps}")
|
||||
else:
|
||||
print("🔒 SERIALIZATION detected: No overlapping execution times")
|
||||
|
||||
# Check final balance
|
||||
final_balance = await credit_system.get_credits(user_id)
|
||||
print(f"Final balance: {final_balance}")
|
||||
|
||||
# Count successes/failures
|
||||
successful = [r for r in results if "SUCCESS" in str(r)]
|
||||
failed = [r for r in results if "FAILED" in str(r)]
|
||||
|
||||
print(f"Successful: {len(successful)}, Failed: {len(failed)}")
|
||||
|
||||
# All should succeed since 150 - (10 + 20 + 30) = 90 > 0
|
||||
assert (
|
||||
len(successful) == 3
|
||||
), f"Expected all 3 to succeed, got {len(successful)} successes: {results}"
|
||||
assert final_balance == 90, f"Expected balance 90, got {final_balance}"
|
||||
|
||||
# Check transaction timestamps to confirm database-level serialization
|
||||
transactions = await CreditTransaction.prisma().find_many(
|
||||
where={"userId": user_id, "type": prisma.enums.CreditTransactionType.USAGE},
|
||||
order={"createdAt": "asc"},
|
||||
)
|
||||
print("\nDatabase transaction order (by createdAt):")
|
||||
for i, tx in enumerate(transactions):
|
||||
print(
|
||||
f" {i+1}. Amount {tx.amount}, Running balance: {tx.runningBalance}, Created: {tx.createdAt}"
|
||||
)
|
||||
|
||||
# Verify running balances are chronologically consistent (ordered by createdAt)
|
||||
actual_balances = [
|
||||
tx.runningBalance for tx in transactions if tx.runningBalance is not None
|
||||
]
|
||||
print(f"Running balances: {actual_balances}")
|
||||
|
||||
# The balances should be valid intermediate states regardless of execution order
|
||||
# Starting balance: 150, spending 10+20+30=60, so final should be 90
|
||||
# The intermediate balances depend on execution order but should all be valid
|
||||
expected_possible_balances = {
|
||||
# If order is 10, 20, 30: [140, 120, 90]
|
||||
# If order is 10, 30, 20: [140, 110, 90]
|
||||
# If order is 20, 10, 30: [130, 120, 90]
|
||||
# If order is 20, 30, 10: [130, 100, 90]
|
||||
# If order is 30, 10, 20: [120, 110, 90]
|
||||
# If order is 30, 20, 10: [120, 100, 90]
|
||||
90,
|
||||
100,
|
||||
110,
|
||||
120,
|
||||
130,
|
||||
140, # All possible intermediate balances
|
||||
}
|
||||
|
||||
# Verify all balances are valid intermediate states
|
||||
for balance in actual_balances:
|
||||
assert (
|
||||
balance in expected_possible_balances
|
||||
), f"Invalid balance {balance}, expected one of {expected_possible_balances}"
|
||||
|
||||
# Final balance should always be 90 (150 - 60)
|
||||
assert (
|
||||
min(actual_balances) == 90
|
||||
), f"Final balance should be 90, got {min(actual_balances)}"
|
        # The final transaction should always have balance 90
        # The other transactions should have valid intermediate balances
        assert (
            90 in actual_balances
        ), f"Final balance 90 should be in actual_balances: {actual_balances}"

        # All balances should be >= 90 (the final state)
        assert all(
            balance >= 90 for balance in actual_balances
        ), f"All balances should be >= 90, got {actual_balances}"

        # CRITICAL: Transactions are atomic but can complete in any order
        # What matters is that all running balances are valid intermediate states
        # Each balance should be between 90 (final) and 140 (after first transaction)
        for balance in actual_balances:
            assert (
                90 <= balance <= 140
            ), f"Balance {balance} is outside valid range [90, 140]"

        # Final balance (minimum) should always be 90
        assert (
            min(actual_balances) == 90
        ), f"Final balance should be 90, got {min(actual_balances)}"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_prove_database_locking_behavior(server: SpinTestServer):
    """Definitively prove whether database locking causes waiting vs failures."""
    user_id = f"locking-test-{uuid4()}"
    await create_test_user(user_id)

    try:
        # Set balance to exact amount that can handle all spends using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=60,  # Exactly 10+20+30
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "exact_amount_test"}),
        )

        async def spend_with_precise_timing(amount: int, label: str):
            request_start = asyncio.get_event_loop().time()
            db_operation_start = asyncio.get_event_loop().time()
            try:
                # Add a small delay to increase chance of true concurrency
                await asyncio.sleep(0.001)

                db_operation_start = asyncio.get_event_loop().time()
                await credit_system.spend_credits(
                    user_id,
                    amount,
                    UsageTransactionMetadata(
                        graph_exec_id=f"locking-{label}",
                        reason=f"Locking test {label}",
                    ),
                )
                db_operation_end = asyncio.get_event_loop().time()

                return {
                    "label": label,
                    "status": "SUCCESS",
                    "request_start": request_start,
                    "db_start": db_operation_start,
                    "db_end": db_operation_end,
                    "db_duration": db_operation_end - db_operation_start,
                }
            except Exception as e:
                db_operation_end = asyncio.get_event_loop().time()
                return {
                    "label": label,
                    "status": "FAILED",
                    "error": str(e),
                    "request_start": request_start,
                    "db_start": db_operation_start,
                    "db_end": db_operation_end,
                    "db_duration": db_operation_end - db_operation_start,
                }

        # Launch all requests simultaneously
        results = await asyncio.gather(
            spend_with_precise_timing(10, "A"),
            spend_with_precise_timing(20, "B"),
            spend_with_precise_timing(30, "C"),
            return_exceptions=True,
        )

        print("\n🔍 LOCKING BEHAVIOR ANALYSIS:")
        print("=" * 50)

        successful = [
            r for r in results if isinstance(r, dict) and r.get("status") == "SUCCESS"
        ]
        failed = [
            r for r in results if isinstance(r, dict) and r.get("status") == "FAILED"
        ]

        print(f"✅ Successful operations: {len(successful)}")
        print(f"❌ Failed operations: {len(failed)}")

        if len(failed) > 0:
            print(
                "\n🚫 CONCURRENT FAILURES - Some requests failed due to insufficient balance:"
            )
            for result in failed:
                if isinstance(result, dict):
                    print(
                        f"  {result['label']}: {result.get('error', 'Unknown error')}"
                    )

        if len(successful) == 3:
            print(
                "\n🔒 SERIALIZATION CONFIRMED - All requests succeeded, indicating they were queued:"
            )

            # Sort by actual execution time to see order
            dict_results = [r for r in results if isinstance(r, dict)]
            sorted_results = sorted(dict_results, key=lambda x: x["db_start"])

            for i, result in enumerate(sorted_results):
                print(
                    f"  {i+1}. {result['label']}: DB operation took {result['db_duration']:.4f}s"
                )

            # Check if any operations overlapped at the database level
            print("\n⏱️ Database operation timeline:")
            for result in sorted_results:
                print(
                    f"  {result['label']}: {result['db_start']:.4f} -> {result['db_end']:.4f}"
                )

        # Verify final state
        final_balance = await credit_system.get_credits(user_id)
        print(f"\n💰 Final balance: {final_balance}")

        if len(successful) == 3:
            assert (
                final_balance == 0
            ), f"If all succeeded, balance should be 0, got {final_balance}"
            print(
                "✅ CONCLUSION: Database row locking causes requests to WAIT and execute serially"
            )
        else:
            print(
                "❌ CONCLUSION: Some requests failed, indicating different concurrency behavior"
            )

    finally:
        await cleanup_test_user(user_id)
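For context on why all three spends succeed serially rather than fail, the behavior this test's conclusion points to is row-level locking on the balance row. Below is a minimal sketch of that pattern, assuming a generic async `db` handle with `transaction`/`query_first`/`execute` helpers; the names, schema, and SQL are illustrative assumptions, not the actual `UserCredit` implementation.

```python
# Hypothetical sketch of row-level locking (not the real UserCredit code).
# `db` is an assumed async database handle; adapt to the actual client.
async def spend_with_row_lock(db, user_id: str, amount: int) -> int:
    async with db.transaction() as tx:
        # FOR UPDATE locks the row; concurrent spenders wait here until commit,
        # which is the serialized behavior the test above observes.
        row = await tx.query_first(
            'SELECT "balance" FROM platform."UserBalance" '
            'WHERE "userId" = $1 FOR UPDATE',
            user_id,
        )
        if row is None or row["balance"] < amount:
            raise ValueError("Insufficient balance")
        new_balance = row["balance"] - amount
        await tx.execute(
            'UPDATE platform."UserBalance" SET "balance" = $1 WHERE "userId" = $2',
            new_balance,
            user_id,
        )
        return new_balance
```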
277
autogpt_platform/backend/backend/data/credit_integration_test.py
Normal file
@@ -0,0 +1,277 @@
"""
Integration tests for credit system to catch SQL enum casting issues.

These tests run actual database operations to ensure SQL queries work correctly,
which would have caught the CreditTransactionType enum casting bug.
"""

import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, User, UserBalance

from backend.data.credit import (
    AutoTopUpConfig,
    BetaUserCredit,
    UsageTransactionMetadata,
    get_auto_top_up,
    set_auto_top_up,
)
from backend.util.json import SafeJson


@pytest.fixture
async def cleanup_test_user():
    """Clean up test user data before and after tests."""
    import uuid

    user_id = str(uuid.uuid4())  # Use unique user ID for each test

    # Create the user first
    try:
        await User.prisma().create(
            data={
                "id": user_id,
                "email": f"test-{user_id}@example.com",
                "topUpConfig": SafeJson({}),
                "timezone": "UTC",
            }
        )
    except Exception:
        # User might already exist, that's fine
        pass

    yield user_id

    # Cleanup after test
    await CreditTransaction.prisma().delete_many(where={"userId": user_id})
    await UserBalance.prisma().delete_many(where={"userId": user_id})
    # Clear auto-top-up config before deleting user
    await User.prisma().update(
        where={"id": user_id}, data={"topUpConfig": SafeJson({})}
    )
    await User.prisma().delete(where={"id": user_id})


@pytest.mark.asyncio(loop_scope="session")
async def test_credit_transaction_enum_casting_integration(cleanup_test_user):
    """
    Integration test to verify CreditTransactionType enum casting works in SQL queries.

    This test would have caught the enum casting bug where PostgreSQL expected
    platform."CreditTransactionType" but got "CreditTransactionType".
    """
    user_id = cleanup_test_user
    credit_system = BetaUserCredit(1000)

    # Test each transaction type to ensure enum casting works
    test_cases = [
        (CreditTransactionType.TOP_UP, 100, "Test top-up"),
        (CreditTransactionType.USAGE, -50, "Test usage"),
        (CreditTransactionType.GRANT, 200, "Test grant"),
        (CreditTransactionType.REFUND, -25, "Test refund"),
        (CreditTransactionType.CARD_CHECK, 0, "Test card check"),
    ]

    for transaction_type, amount, reason in test_cases:
        metadata = SafeJson({"reason": reason, "test": "enum_casting"})

        # This call would fail with enum casting error before the fix
        balance, tx_key = await credit_system._add_transaction(
            user_id=user_id,
            amount=amount,
            transaction_type=transaction_type,
            metadata=metadata,
            is_active=True,
        )

        # Verify transaction was created with correct type
        transaction = await CreditTransaction.prisma().find_first(
            where={"userId": user_id, "transactionKey": tx_key}
        )

        assert transaction is not None
        assert transaction.type == transaction_type
        assert transaction.amount == amount
        assert transaction.metadata is not None

        # Verify metadata content
        assert transaction.metadata["reason"] == reason
        assert transaction.metadata["test"] == "enum_casting"


@pytest.mark.asyncio(loop_scope="session")
async def test_auto_top_up_integration(cleanup_test_user, monkeypatch):
    """
    Integration test for auto-top-up functionality that triggers enum casting.

    This tests the complete auto-top-up flow which involves SQL queries with
    CreditTransactionType enums, ensuring enum casting works end-to-end.
    """
    # Enable credits for this test
    from backend.data.credit import settings

    monkeypatch.setattr(settings.config, "enable_credit", True)
    monkeypatch.setattr(settings.config, "enable_beta_monthly_credit", True)
    monkeypatch.setattr(settings.config, "num_user_credits_refill", 1000)

    user_id = cleanup_test_user
    credit_system = BetaUserCredit(1000)

    # First add some initial credits so we can test the configuration and subsequent behavior
    balance, _ = await credit_system._add_transaction(
        user_id=user_id,
        amount=50,  # Below threshold that we'll set
        transaction_type=CreditTransactionType.GRANT,
        metadata=SafeJson({"reason": "Initial credits before auto top-up config"}),
    )
    assert balance == 50

    # Configure auto top-up with threshold above current balance
    config = AutoTopUpConfig(threshold=100, amount=500)
    await set_auto_top_up(user_id, config)

    # Verify configuration was saved but no immediate top-up occurred
    current_balance = await credit_system.get_credits(user_id)
    assert current_balance == 50  # Balance should be unchanged

    # Simulate spending credits that would trigger auto top-up
    # This involves multiple SQL operations with enum casting
    try:
        metadata = UsageTransactionMetadata(reason="Test spend to trigger auto top-up")
        await credit_system.spend_credits(user_id=user_id, cost=10, metadata=metadata)

        # The auto top-up mechanism should have been triggered
        # Verify the transaction types were handled correctly
        transactions = await CreditTransaction.prisma().find_many(
            where={"userId": user_id}, order={"createdAt": "desc"}
        )

        # Should have at least: GRANT (initial), USAGE (spend), and TOP_UP (auto top-up)
        assert len(transactions) >= 3

        # Verify different transaction types exist and enum casting worked
        transaction_types = {t.type for t in transactions}
        assert CreditTransactionType.GRANT in transaction_types
        assert CreditTransactionType.USAGE in transaction_types
        assert (
            CreditTransactionType.TOP_UP in transaction_types
        )  # Auto top-up should have triggered

    except Exception as e:
        # If this fails with enum casting error, the test successfully caught the bug
        if "CreditTransactionType" in str(e) and (
            "cast" in str(e).lower() or "type" in str(e).lower()
        ):
            pytest.fail(f"Enum casting error detected: {e}")
        else:
            # Re-raise other unexpected errors
            raise


@pytest.mark.asyncio(loop_scope="session")
async def test_enable_transaction_enum_casting_integration(cleanup_test_user):
    """
    Integration test for _enable_transaction with enum casting.

    Tests the scenario where inactive transactions are enabled, which also
    involves SQL queries with CreditTransactionType enum casting.
    """
    user_id = cleanup_test_user
    credit_system = BetaUserCredit(1000)

    # Create an inactive transaction
    balance, tx_key = await credit_system._add_transaction(
        user_id=user_id,
        amount=100,
        transaction_type=CreditTransactionType.TOP_UP,
        metadata=SafeJson({"reason": "Inactive transaction test"}),
        is_active=False,  # Create as inactive
    )

    # Balance should be 0 since transaction is inactive
    assert balance == 0

    # Enable the transaction with new metadata
    enable_metadata = SafeJson(
        {
            "payment_method": "test_payment",
            "activation_reason": "Integration test activation",
        }
    )

    # This would fail with enum casting error before the fix
    final_balance = await credit_system._enable_transaction(
        transaction_key=tx_key,
        user_id=user_id,
        metadata=enable_metadata,
    )

    # Now balance should reflect the activated transaction
    assert final_balance == 100

    # Verify transaction was properly enabled with correct enum type
    transaction = await CreditTransaction.prisma().find_first(
        where={"userId": user_id, "transactionKey": tx_key}
    )

    assert transaction is not None
    assert transaction.isActive is True
    assert transaction.type == CreditTransactionType.TOP_UP
    assert transaction.runningBalance == 100

    # Verify metadata was updated
    assert transaction.metadata is not None
    assert transaction.metadata["payment_method"] == "test_payment"
    assert transaction.metadata["activation_reason"] == "Integration test activation"


@pytest.mark.asyncio(loop_scope="session")
async def test_auto_top_up_configuration_storage(cleanup_test_user, monkeypatch):
    """
    Test that auto-top-up configuration is properly stored and retrieved.

    The immediate top-up logic is handled by the API routes, not the core
    set_auto_top_up function. This test verifies the configuration is correctly
    saved and can be retrieved.
    """
    # Enable credits for this test
    from backend.data.credit import settings

    monkeypatch.setattr(settings.config, "enable_credit", True)
    monkeypatch.setattr(settings.config, "enable_beta_monthly_credit", True)
    monkeypatch.setattr(settings.config, "num_user_credits_refill", 1000)

    user_id = cleanup_test_user
    credit_system = BetaUserCredit(1000)

    # Set initial balance
    balance, _ = await credit_system._add_transaction(
        user_id=user_id,
        amount=50,
        transaction_type=CreditTransactionType.GRANT,
        metadata=SafeJson({"reason": "Initial balance for config test"}),
    )

    assert balance == 50

    # Configure auto top-up
    config = AutoTopUpConfig(threshold=100, amount=200)
    await set_auto_top_up(user_id, config)

    # Verify the configuration was saved
    retrieved_config = await get_auto_top_up(user_id)
    assert retrieved_config.threshold == config.threshold
    assert retrieved_config.amount == config.amount

    # Verify balance is unchanged (no immediate top-up from set_auto_top_up)
    final_balance = await credit_system.get_credits(user_id)
    assert final_balance == 50  # Should be unchanged

    # Verify no immediate auto-top-up transaction was created by set_auto_top_up
    transactions = await CreditTransaction.prisma().find_many(
        where={"userId": user_id}, order={"createdAt": "desc"}
    )

    # Should only have the initial GRANT transaction
    assert len(transactions) == 1
    assert transactions[0].type == CreditTransactionType.GRANT
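The casting bug these tests target typically comes from binding the enum as bare text in raw SQL. Below is a hedged sketch of the schema-qualified cast that sidesteps it; the table name, column list, and `execute_raw` helper are assumptions made for illustration, grounded only in the `platform."CreditTransactionType"` error message quoted in the docstring above.

```python
# Illustrative only — not the repository's actual query.
INSERT_TRANSACTION_SQL = """
INSERT INTO platform."CreditTransaction" ("userId", "amount", "type")
VALUES ($1, $2, $3::platform."CreditTransactionType")
"""


async def insert_transaction(db, user_id: str, amount: int, tx_type: str) -> None:
    # Casting the bound text parameter to the schema-qualified enum type avoids
    # the "expected platform.CreditTransactionType" error these tests guard against.
    # `db.execute_raw` stands in for whatever raw-query helper the codebase uses.
    await db.execute_raw(INSERT_TRANSACTION_SQL, user_id, amount, tx_type)
```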
141
autogpt_platform/backend/backend/data/credit_metadata_test.py
Normal file
@@ -0,0 +1,141 @@
"""
Tests for credit system metadata handling to ensure JSON casting works correctly.

This test verifies that metadata parameters are properly serialized when passed
to raw SQL queries with JSONB columns.
"""

# type: ignore

from typing import Any

import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, UserBalance

from backend.data.credit import BetaUserCredit
from backend.data.user import DEFAULT_USER_ID
from backend.util.json import SafeJson


@pytest.fixture
async def setup_test_user():
    """Setup test user and cleanup after test."""
    user_id = DEFAULT_USER_ID

    # Cleanup before test
    await CreditTransaction.prisma().delete_many(where={"userId": user_id})
    await UserBalance.prisma().delete_many(where={"userId": user_id})

    yield user_id

    # Cleanup after test
    await CreditTransaction.prisma().delete_many(where={"userId": user_id})
    await UserBalance.prisma().delete_many(where={"userId": user_id})


@pytest.mark.asyncio(loop_scope="session")
async def test_metadata_json_serialization(setup_test_user):
    """Test that metadata is properly serialized for JSONB column in raw SQL."""
    user_id = setup_test_user
    credit_system = BetaUserCredit(1000)

    # Test with complex metadata that would fail if not properly serialized
    complex_metadata = SafeJson(
        {
            "graph_exec_id": "test-12345",
            "reason": "Testing metadata serialization",
            "nested_data": {
                "key1": "value1",
                "key2": ["array", "of", "values"],
                "key3": {"deeply": {"nested": "object"}},
            },
            "special_chars": "Testing 'quotes' and \"double quotes\" and unicode: 🚀",
        }
    )

    # This should work without throwing a JSONB casting error
    balance, tx_key = await credit_system._add_transaction(
        user_id=user_id,
        amount=500,  # $5 top-up
        transaction_type=CreditTransactionType.TOP_UP,
        metadata=complex_metadata,
        is_active=True,
    )

    # Verify the transaction was created successfully
    assert balance == 500

    # Verify the metadata was stored correctly in the database
    transaction = await CreditTransaction.prisma().find_first(
        where={"userId": user_id, "transactionKey": tx_key}
    )

    assert transaction is not None
    assert transaction.metadata is not None

    # Verify the metadata contains our complex data
    metadata_dict: dict[str, Any] = dict(transaction.metadata)  # type: ignore
    assert metadata_dict["graph_exec_id"] == "test-12345"
    assert metadata_dict["reason"] == "Testing metadata serialization"
    assert metadata_dict["nested_data"]["key1"] == "value1"
    assert metadata_dict["nested_data"]["key3"]["deeply"]["nested"] == "object"
    assert (
        metadata_dict["special_chars"]
        == "Testing 'quotes' and \"double quotes\" and unicode: 🚀"
    )


@pytest.mark.asyncio(loop_scope="session")
async def test_enable_transaction_metadata_serialization(setup_test_user):
    """Test that _enable_transaction also handles metadata JSON serialization correctly."""
    user_id = setup_test_user
    credit_system = BetaUserCredit(1000)

    # First create an inactive transaction
    balance, tx_key = await credit_system._add_transaction(
        user_id=user_id,
        amount=300,
        transaction_type=CreditTransactionType.TOP_UP,
        metadata=SafeJson({"initial": "inactive_transaction"}),
        is_active=False,  # Create as inactive
    )

    # Initial balance should be 0 because transaction is inactive
    assert balance == 0

    # Now enable the transaction with new metadata
    enable_metadata = SafeJson(
        {
            "payment_method": "stripe",
            "payment_intent": "pi_test_12345",
            "activation_reason": "Payment confirmed",
            "complex_data": {"array": [1, 2, 3], "boolean": True, "null_value": None},
        }
    )

    # This should work without JSONB casting errors
    final_balance = await credit_system._enable_transaction(
        transaction_key=tx_key,
        user_id=user_id,
        metadata=enable_metadata,
    )

    # Now balance should reflect the activated transaction
    assert final_balance == 300

    # Verify the metadata was updated correctly
    transaction = await CreditTransaction.prisma().find_first(
        where={"userId": user_id, "transactionKey": tx_key}
    )

    assert transaction is not None
    assert transaction.isActive is True

    # Verify the metadata was updated with enable_metadata
    metadata_dict: dict[str, Any] = dict(transaction.metadata)  # type: ignore
    assert metadata_dict["payment_method"] == "stripe"
    assert metadata_dict["payment_intent"] == "pi_test_12345"
    assert metadata_dict["complex_data"]["array"] == [1, 2, 3]
    assert metadata_dict["complex_data"]["boolean"] is True
    assert metadata_dict["complex_data"]["null_value"] is None
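Behind these tests is the requirement that metadata reach the JSONB column as a proper JSON value. A minimal sketch of that serialization step follows, assuming a generic `execute_raw` helper and an illustrative table; `SafeJson` in the codebase may already perform an equivalent step.

```python
import json

# Illustrative only — not the repository's actual query.
INSERT_WITH_METADATA_SQL = """
INSERT INTO platform."CreditTransaction" ("userId", "amount", "metadata")
VALUES ($1, $2, $3::jsonb)
"""


async def insert_with_metadata(db, user_id: str, amount: int, metadata: dict) -> None:
    # Bind a JSON string and cast it to jsonb; binding a raw Python dict is what
    # typically triggers the casting errors these tests are written to catch.
    await db.execute_raw(
        INSERT_WITH_METADATA_SQL, user_id, amount, json.dumps(metadata)
    )
```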
372
autogpt_platform/backend/backend/data/credit_refund_test.py
Normal file
@@ -0,0 +1,372 @@
|
||||
"""
|
||||
Tests for credit system refund and dispute operations.
|
||||
|
||||
These tests ensure that refund operations (deduct_credits, handle_dispute)
|
||||
are atomic and maintain data consistency.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import stripe
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
credit_system = UserCredit()
|
||||
|
||||
# Test user ID for refund tests
|
||||
REFUND_TEST_USER_ID = "refund-test-user"
|
||||
|
||||
|
||||
async def setup_test_user_with_topup():
|
||||
"""Create a test user with initial balance and a top-up transaction."""
|
||||
# Clean up any existing data
|
||||
await CreditRefundRequest.prisma().delete_many(
|
||||
where={"userId": REFUND_TEST_USER_ID}
|
||||
)
|
||||
await CreditTransaction.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
|
||||
await UserBalance.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
|
||||
await User.prisma().delete_many(where={"id": REFUND_TEST_USER_ID})
|
||||
|
||||
# Create user
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": REFUND_TEST_USER_ID,
|
||||
"email": f"{REFUND_TEST_USER_ID}@example.com",
|
||||
"name": "Refund Test User",
|
||||
}
|
||||
)
|
||||
|
||||
# Create user balance
|
||||
await UserBalance.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"balance": 1000, # $10
|
||||
}
|
||||
)
|
||||
|
||||
# Create a top-up transaction that can be refunded
|
||||
topup_tx = await CreditTransaction.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 1000,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
"transactionKey": "pi_test_12345",
|
||||
"runningBalance": 1000,
|
||||
"isActive": True,
|
||||
"metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
|
||||
}
|
||||
)
|
||||
|
||||
return topup_tx
|
||||
|
||||
|
||||
async def cleanup_test_user():
|
||||
"""Clean up test data."""
|
||||
await CreditRefundRequest.prisma().delete_many(
|
||||
where={"userId": REFUND_TEST_USER_ID}
|
||||
)
|
||||
await CreditTransaction.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
|
||||
await UserBalance.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
|
||||
await User.prisma().delete_many(where={"id": REFUND_TEST_USER_ID})
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_deduct_credits_atomic(server: SpinTestServer):
|
||||
"""Test that deduct_credits is atomic and creates transaction correctly."""
|
||||
topup_tx = await setup_test_user_with_topup()
|
||||
|
||||
try:
|
||||
# Create a mock refund object
|
||||
refund = MagicMock(spec=stripe.Refund)
|
||||
refund.id = "re_test_refund_123"
|
||||
refund.payment_intent = topup_tx.transactionKey
|
||||
refund.amount = 500 # Refund $5 of the $10 top-up
|
||||
refund.status = "succeeded"
|
||||
refund.reason = "requested_by_customer"
|
||||
refund.created = int(datetime.now(timezone.utc).timestamp())
|
||||
|
||||
# Create refund request record (simulating webhook flow)
|
||||
await CreditRefundRequest.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 500,
|
||||
"transactionKey": topup_tx.transactionKey, # Should match the original transaction
|
||||
"reason": "Test refund",
|
||||
}
|
||||
)
|
||||
|
||||
# Call deduct_credits
|
||||
await credit_system.deduct_credits(refund)
|
||||
|
||||
# Verify the user's balance was deducted
|
||||
user_balance = await UserBalance.prisma().find_unique(
|
||||
where={"userId": REFUND_TEST_USER_ID}
|
||||
)
|
||||
assert user_balance is not None
|
||||
assert (
|
||||
user_balance.balance == 500
|
||||
), f"Expected balance 500, got {user_balance.balance}"
|
||||
|
||||
# Verify refund transaction was created
|
||||
refund_tx = await CreditTransaction.prisma().find_first(
|
||||
where={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"type": CreditTransactionType.REFUND,
|
||||
"transactionKey": refund.id,
|
||||
}
|
||||
)
|
||||
assert refund_tx is not None
|
||||
assert refund_tx.amount == -500
|
||||
assert refund_tx.runningBalance == 500
|
||||
assert refund_tx.isActive
|
||||
|
||||
# Verify refund request was updated
|
||||
refund_request = await CreditRefundRequest.prisma().find_first(
|
||||
where={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"transactionKey": topup_tx.transactionKey,
|
||||
}
|
||||
)
|
||||
assert refund_request is not None
|
||||
assert (
|
||||
refund_request.result
|
||||
== "The refund request has been approved, the amount will be credited back to your account."
|
||||
)
|
||||
|
||||
finally:
|
||||
await cleanup_test_user()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_deduct_credits_user_not_found(server: SpinTestServer):
|
||||
"""Test that deduct_credits raises error if transaction not found (which means user doesn't exist)."""
|
||||
# Create a mock refund object that references a non-existent payment intent
|
||||
refund = MagicMock(spec=stripe.Refund)
|
||||
refund.id = "re_test_refund_nonexistent"
|
||||
refund.payment_intent = "pi_test_nonexistent" # This payment intent doesn't exist
|
||||
refund.amount = 500
|
||||
refund.status = "succeeded"
|
||||
refund.reason = "requested_by_customer"
|
||||
refund.created = int(datetime.now(timezone.utc).timestamp())
|
||||
|
||||
# Should raise error for missing transaction
|
||||
with pytest.raises(Exception): # Should raise NotFoundError for missing transaction
|
||||
await credit_system.deduct_credits(refund)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.data.credit.settings")
|
||||
@patch("stripe.Dispute.modify")
|
||||
@patch("backend.data.credit.get_user_by_id")
|
||||
async def test_handle_dispute_with_sufficient_balance(
|
||||
mock_get_user, mock_stripe_modify, mock_settings, server: SpinTestServer
|
||||
):
|
||||
"""Test handling dispute when user has sufficient balance (dispute gets closed)."""
|
||||
topup_tx = await setup_test_user_with_topup()
|
||||
|
||||
try:
|
||||
# Mock settings to have a low tolerance threshold
|
||||
mock_settings.config.refund_credit_tolerance_threshold = 0
|
||||
|
||||
# Mock the user lookup
|
||||
mock_user = MagicMock()
|
||||
mock_user.email = f"{REFUND_TEST_USER_ID}@example.com"
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Create a mock dispute object for small amount (user has 1000, disputing 100)
|
||||
dispute = MagicMock(spec=stripe.Dispute)
|
||||
dispute.id = "dp_test_dispute_123"
|
||||
dispute.payment_intent = topup_tx.transactionKey
|
||||
dispute.amount = 100 # Small dispute amount
|
||||
dispute.status = "pending"
|
||||
dispute.reason = "fraudulent"
|
||||
dispute.created = int(datetime.now(timezone.utc).timestamp())
|
||||
|
||||
# Mock the close method to prevent real API calls
|
||||
dispute.close = MagicMock()
|
||||
|
||||
# Handle the dispute
|
||||
await credit_system.handle_dispute(dispute)
|
||||
|
||||
# Verify dispute.close() was called (since user has sufficient balance)
|
||||
dispute.close.assert_called_once()
|
||||
|
||||
# Verify no stripe evidence was added since dispute was closed
|
||||
mock_stripe_modify.assert_not_called()
|
||||
|
||||
# Verify the user's balance was NOT deducted (dispute was closed)
|
||||
user_balance = await UserBalance.prisma().find_unique(
|
||||
where={"userId": REFUND_TEST_USER_ID}
|
||||
)
|
||||
assert user_balance is not None
|
||||
assert (
|
||||
user_balance.balance == 1000
|
||||
), f"Balance should remain 1000, got {user_balance.balance}"
|
||||
|
||||
finally:
|
||||
await cleanup_test_user()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.data.credit.settings")
|
||||
@patch("stripe.Dispute.modify")
|
||||
@patch("backend.data.credit.get_user_by_id")
|
||||
async def test_handle_dispute_with_insufficient_balance(
|
||||
mock_get_user, mock_stripe_modify, mock_settings, server: SpinTestServer
|
||||
):
|
||||
"""Test handling dispute when user has insufficient balance (evidence gets added)."""
|
||||
topup_tx = await setup_test_user_with_topup()
|
||||
|
||||
# Save original method for restoration before any try blocks
|
||||
original_get_history = credit_system.get_transaction_history
|
||||
|
||||
try:
|
||||
# Mock settings to have a high tolerance threshold so dispute isn't closed
|
||||
mock_settings.config.refund_credit_tolerance_threshold = 2000
|
||||
|
||||
# Mock the user lookup
|
||||
mock_user = MagicMock()
|
||||
mock_user.email = f"{REFUND_TEST_USER_ID}@example.com"
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Mock the transaction history method to return an async result
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
mock_history = MagicMock()
|
||||
mock_history.transactions = []
|
||||
credit_system.get_transaction_history = AsyncMock(return_value=mock_history)
|
||||
|
||||
# Create a mock dispute object for full amount (user has 1000, disputing 1000)
|
||||
dispute = MagicMock(spec=stripe.Dispute)
|
||||
dispute.id = "dp_test_dispute_pending"
|
||||
dispute.payment_intent = topup_tx.transactionKey
|
||||
dispute.amount = 1000
|
||||
dispute.status = "warning_needs_response"
|
||||
dispute.created = int(datetime.now(timezone.utc).timestamp())
|
||||
|
||||
# Mock the close method to prevent real API calls
|
||||
dispute.close = MagicMock()
|
||||
|
||||
# Handle the dispute (evidence should be added)
|
||||
await credit_system.handle_dispute(dispute)
|
||||
|
||||
# Verify dispute.close() was NOT called (insufficient balance after tolerance)
|
||||
dispute.close.assert_not_called()
|
||||
|
||||
# Verify stripe evidence was added since dispute wasn't closed
|
||||
mock_stripe_modify.assert_called_once()
|
||||
|
||||
# Verify the user's balance was NOT deducted (handle_dispute doesn't deduct credits)
|
||||
user_balance = await UserBalance.prisma().find_unique(
|
||||
where={"userId": REFUND_TEST_USER_ID}
|
||||
)
|
||||
assert user_balance is not None
|
||||
assert user_balance.balance == 1000, "Balance should remain unchanged"
|
||||
|
||||
finally:
|
||||
credit_system.get_transaction_history = original_get_history
|
||||
await cleanup_test_user()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_concurrent_refunds(server: SpinTestServer):
|
||||
"""Test that concurrent refunds are handled atomically."""
|
||||
import asyncio
|
||||
|
||||
topup_tx = await setup_test_user_with_topup()
|
||||
|
||||
try:
|
||||
# Create multiple refund requests
|
||||
refund_requests = []
|
||||
for i in range(5):
|
||||
req = await CreditRefundRequest.prisma().create(
|
||||
data={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"amount": 100, # $1 each
|
||||
"transactionKey": topup_tx.transactionKey,
|
||||
"reason": f"Test refund {i}",
|
||||
}
|
||||
)
|
||||
refund_requests.append(req)
|
||||
|
||||
# Create refund tasks to run concurrently
|
||||
async def process_refund(index: int):
|
||||
refund = MagicMock(spec=stripe.Refund)
|
||||
refund.id = f"re_test_concurrent_{index}"
|
||||
refund.payment_intent = topup_tx.transactionKey
|
||||
refund.amount = 100 # $1 refund
|
||||
refund.status = "succeeded"
|
||||
refund.reason = "requested_by_customer"
|
||||
refund.created = int(datetime.now(timezone.utc).timestamp())
|
||||
|
||||
try:
|
||||
await credit_system.deduct_credits(refund)
|
||||
return "success"
|
||||
except Exception as e:
|
||||
return f"error: {e}"
|
||||
|
||||
# Run refunds concurrently
|
||||
results = await asyncio.gather(
|
||||
*[process_refund(i) for i in range(5)], return_exceptions=True
|
||||
)
|
||||
|
||||
# All should succeed
|
||||
assert all(r == "success" for r in results), f"Some refunds failed: {results}"
|
||||
|
||||
# Verify final balance - with non-atomic implementation, this will demonstrate race condition
|
||||
# EXPECTED BEHAVIOR: Due to race conditions, not all refunds will be properly processed
|
||||
# The balance will be incorrect (higher than expected) showing lost updates
|
||||
user_balance = await UserBalance.prisma().find_unique(
|
||||
where={"userId": REFUND_TEST_USER_ID}
|
||||
)
|
||||
assert user_balance is not None
|
||||
|
||||
# With atomic implementation, this should be 500 (1000 - 5*100)
|
||||
# With current non-atomic implementation, this will likely be wrong due to race conditions
|
||||
print(f"DEBUG: Final balance = {user_balance.balance}, expected = 500")
|
||||
|
||||
# With atomic implementation, all 5 refunds should process correctly
|
||||
assert (
|
||||
user_balance.balance == 500
|
||||
), f"Expected balance 500 after 5 refunds of 100 each, got {user_balance.balance}"
|
||||
|
||||
# Verify all refund transactions exist
|
||||
refund_txs = await CreditTransaction.prisma().find_many(
|
||||
where={
|
||||
"userId": REFUND_TEST_USER_ID,
|
||||
"type": CreditTransactionType.REFUND,
|
||||
}
|
||||
)
|
||||
assert (
|
||||
len(refund_txs) == 5
|
||||
), f"Expected 5 refund transactions, got {len(refund_txs)}"
|
||||
|
||||
running_balances: set[int] = {
|
||||
tx.runningBalance for tx in refund_txs if tx.runningBalance is not None
|
||||
}
|
||||
|
||||
# Verify all balances are valid intermediate states
|
||||
for balance in running_balances:
|
||||
assert (
|
||||
500 <= balance <= 1000
|
||||
), f"Invalid balance {balance}, should be between 500 and 1000"
|
||||
|
||||
# Final balance should be present
|
||||
assert (
|
||||
500 in running_balances
|
||||
), f"Final balance 500 should be in {running_balances}"
|
||||
|
||||
# All balances should be unique and form a valid sequence
|
||||
sorted_balances = sorted(running_balances, reverse=True)
|
||||
assert (
|
||||
len(sorted_balances) == 5
|
||||
), f"Expected 5 unique balances, got {len(sorted_balances)}"
|
||||
|
||||
finally:
|
||||
await cleanup_test_user()
|
||||
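The concurrent-refund test above expects five parallel deductions to land without lost updates. A sketch of a single-statement decrement that would give that guarantee is shown below; the SQL and the `query_first` helper are illustrative assumptions, not the actual `deduct_credits` implementation.

```python
# Illustrative only — a single UPDATE computes the new balance inside the
# database, so concurrent refunds cannot overwrite each other's read-modify-write.
ATOMIC_DEDUCT_SQL = """
UPDATE platform."UserBalance"
SET "balance" = "balance" - $2
WHERE "userId" = $1
RETURNING "balance"
"""


async def deduct_atomically(db, user_id: str, amount: int) -> int:
    # `db.query_first` stands in for whatever raw-query helper the codebase uses.
    row = await db.query_first(ATOMIC_DEDUCT_SQL, user_id, amount)
    return row["balance"]
```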
@@ -1,8 +1,8 @@
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.models import CreditTransaction
|
||||
from prisma.models import CreditTransaction, UserBalance
|
||||
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.data.block import get_block
|
||||
@@ -19,14 +19,24 @@ user_credit = BetaUserCredit(REFILL_VALUE)
|
||||
|
||||
async def disable_test_user_transactions():
|
||||
await CreditTransaction.prisma().delete_many(where={"userId": DEFAULT_USER_ID})
|
||||
# Also reset the balance to 0 and set updatedAt to old date to trigger monthly refill
|
||||
old_date = datetime.now(timezone.utc) - timedelta(days=35) # More than a month ago
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": DEFAULT_USER_ID},
|
||||
data={
|
||||
"create": {"userId": DEFAULT_USER_ID, "balance": 0},
|
||||
"update": {"balance": 0, "updatedAt": old_date},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def top_up(amount: int):
|
||||
await user_credit._add_transaction(
|
||||
balance, _ = await user_credit._add_transaction(
|
||||
DEFAULT_USER_ID,
|
||||
amount,
|
||||
CreditTransactionType.TOP_UP,
|
||||
)
|
||||
return balance
|
||||
|
||||
|
||||
async def spend_credits(entry: NodeExecutionEntry) -> int:
|
||||
@@ -111,29 +121,90 @@ async def test_block_credit_top_up(server: SpinTestServer):
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_credit_reset(server: SpinTestServer):
|
||||
"""Test that BetaUserCredit provides monthly refills correctly."""
|
||||
await disable_test_user_transactions()
|
||||
month1 = 1
|
||||
month2 = 2
|
||||
|
||||
# set the calendar to month 2 but use current time from now
|
||||
user_credit.time_now = lambda: datetime.now(timezone.utc).replace(
|
||||
month=month2, day=1
|
||||
)
|
||||
month2credit = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||
# Save original time_now function for restoration
|
||||
original_time_now = user_credit.time_now
|
||||
|
||||
# Month 1 result should only affect month 1
|
||||
user_credit.time_now = lambda: datetime.now(timezone.utc).replace(
|
||||
month=month1, day=1
|
||||
)
|
||||
month1credit = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||
await top_up(100)
|
||||
assert await user_credit.get_credits(DEFAULT_USER_ID) == month1credit + 100
|
||||
try:
|
||||
# Test month 1 behavior
|
||||
month1 = datetime.now(timezone.utc).replace(month=1, day=1)
|
||||
user_credit.time_now = lambda: month1
|
||||
|
||||
# Month 2 balance is unaffected
|
||||
user_credit.time_now = lambda: datetime.now(timezone.utc).replace(
|
||||
month=month2, day=1
|
||||
)
|
||||
assert await user_credit.get_credits(DEFAULT_USER_ID) == month2credit
|
||||
# First call in month 1 should trigger refill
|
||||
balance = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||
assert balance == REFILL_VALUE # Should get 1000 credits
|
||||
|
||||
# Manually create a transaction with month 1 timestamp to establish history
|
||||
await CreditTransaction.prisma().create(
|
||||
data={
|
||||
"userId": DEFAULT_USER_ID,
|
||||
"amount": 100,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
"runningBalance": 1100,
|
||||
"isActive": True,
|
||||
"createdAt": month1, # Set specific timestamp
|
||||
}
|
||||
)
|
||||
|
||||
# Update user balance to match
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": DEFAULT_USER_ID},
|
||||
data={
|
||||
"create": {"userId": DEFAULT_USER_ID, "balance": 1100},
|
||||
"update": {"balance": 1100},
|
||||
},
|
||||
)
|
||||
|
||||
# Now test month 2 behavior
|
||||
month2 = datetime.now(timezone.utc).replace(month=2, day=1)
|
||||
user_credit.time_now = lambda: month2
|
||||
|
||||
# In month 2, since balance (1100) > refill (1000), no refill should happen
|
||||
month2_balance = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||
assert month2_balance == 1100 # Balance persists, no reset
|
||||
|
||||
# Now test the refill behavior when balance is low
|
||||
# Set balance below refill threshold
|
||||
await UserBalance.prisma().update(
|
||||
where={"userId": DEFAULT_USER_ID}, data={"balance": 400}
|
||||
)
|
||||
|
||||
# Create a month 2 transaction to update the last transaction time
|
||||
await CreditTransaction.prisma().create(
|
||||
data={
|
||||
"userId": DEFAULT_USER_ID,
|
||||
"amount": -700, # Spent 700 to get to 400
|
||||
"type": CreditTransactionType.USAGE,
|
||||
"runningBalance": 400,
|
||||
"isActive": True,
|
||||
"createdAt": month2,
|
||||
}
|
||||
)
|
||||
|
||||
# Move to month 3
|
||||
month3 = datetime.now(timezone.utc).replace(month=3, day=1)
|
||||
user_credit.time_now = lambda: month3
|
||||
|
||||
# Should get refilled since balance (400) < refill value (1000)
|
||||
month3_balance = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||
assert month3_balance == REFILL_VALUE # Should be refilled to 1000
|
||||
|
||||
# Verify the refill transaction was created
|
||||
refill_tx = await CreditTransaction.prisma().find_first(
|
||||
where={
|
||||
"userId": DEFAULT_USER_ID,
|
||||
"type": CreditTransactionType.GRANT,
|
||||
"transactionKey": {"contains": "MONTHLY-CREDIT-TOP-UP"},
|
||||
},
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
assert refill_tx is not None, "Monthly refill transaction should be created"
|
||||
assert refill_tx.amount == 600, "Refill should be 600 (1000 - 400)"
|
||||
finally:
|
||||
# Restore original time_now function
|
||||
user_credit.time_now = original_time_now
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
|
||||
361
autogpt_platform/backend/backend/data/credit_underflow_test.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
Test underflow protection for cumulative refunds and negative transactions.
|
||||
|
||||
This test ensures that when multiple large refunds are processed, the user balance
|
||||
doesn't underflow below POSTGRES_INT_MIN, which could cause integer wraparound issues.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User, UserBalance
|
||||
|
||||
from backend.data.credit import POSTGRES_INT_MIN, UserCredit
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
async def create_test_user(user_id: str) -> None:
|
||||
"""Create a test user for underflow tests."""
|
||||
try:
|
||||
await User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"test-{user_id}@example.com",
|
||||
"name": f"Test User {user_id[:8]}",
|
||||
}
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# User already exists, continue
|
||||
pass
|
||||
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
|
||||
)
|
||||
|
||||
|
||||
async def cleanup_test_user(user_id: str) -> None:
|
||||
"""Clean up test user and their transactions."""
|
||||
try:
|
||||
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
|
||||
await UserBalance.prisma().delete_many(where={"userId": user_id})
|
||||
await User.prisma().delete_many(where={"id": user_id})
|
||||
except Exception as e:
|
||||
# Log cleanup failures but don't fail the test
|
||||
print(f"Warning: Failed to cleanup test user {user_id}: {e}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_debug_underflow_step_by_step(server: SpinTestServer):
|
||||
"""Debug underflow behavior step by step."""
|
||||
credit_system = UserCredit()
|
||||
user_id = f"debug-underflow-{uuid4()}"
|
||||
await create_test_user(user_id)
|
||||
|
||||
try:
|
||||
print(f"POSTGRES_INT_MIN: {POSTGRES_INT_MIN}")
|
||||
|
||||
# Test 1: Set up balance close to underflow threshold
|
||||
print("\n=== Test 1: Setting up balance close to underflow threshold ===")
|
||||
# First, manually set balance to a value very close to POSTGRES_INT_MIN
|
||||
# We'll set it to POSTGRES_INT_MIN + 100, then try to subtract 200
|
||||
# This should trigger underflow protection: (POSTGRES_INT_MIN + 100) + (-200) = POSTGRES_INT_MIN - 100
|
||||
initial_balance_target = POSTGRES_INT_MIN + 100
|
||||
|
||||
# Use direct database update to set the balance close to underflow
|
||||
from prisma.models import UserBalance
|
||||
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": initial_balance_target},
|
||||
"update": {"balance": initial_balance_target},
|
||||
},
|
||||
)
|
||||
|
||||
current_balance = await credit_system.get_credits(user_id)
|
||||
print(f"Set balance to: {current_balance}")
|
||||
assert current_balance == initial_balance_target
|
||||
|
||||
# Test 2: Apply amount that should cause underflow
|
||||
print("\n=== Test 2: Testing underflow protection ===")
|
||||
test_amount = (
|
||||
-200
|
||||
) # This should cause underflow: (POSTGRES_INT_MIN + 100) + (-200) = POSTGRES_INT_MIN - 100
|
||||
expected_without_protection = current_balance + test_amount
|
||||
print(f"Current balance: {current_balance}")
|
||||
print(f"Test amount: {test_amount}")
|
||||
print(f"Without protection would be: {expected_without_protection}")
|
||||
print(f"Should be clamped to POSTGRES_INT_MIN: {POSTGRES_INT_MIN}")
|
||||
|
||||
# Apply the amount that should trigger underflow protection
|
||||
balance_result, _ = await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=test_amount,
|
||||
transaction_type=CreditTransactionType.REFUND,
|
||||
fail_insufficient_credits=False,
|
||||
)
|
||||
print(f"Actual result: {balance_result}")
|
||||
|
||||
# Check if underflow protection worked
|
||||
assert (
|
||||
balance_result == POSTGRES_INT_MIN
|
||||
), f"Expected underflow protection to clamp balance to {POSTGRES_INT_MIN}, got {balance_result}"
|
||||
|
||||
# Test 3: Edge case - exactly at POSTGRES_INT_MIN
|
||||
print("\n=== Test 3: Testing exact POSTGRES_INT_MIN boundary ===")
|
||||
# Set balance to exactly POSTGRES_INT_MIN
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
|
||||
"update": {"balance": POSTGRES_INT_MIN},
|
||||
},
|
||||
)
|
||||
|
||||
edge_balance = await credit_system.get_credits(user_id)
|
||||
print(f"Balance set to exactly POSTGRES_INT_MIN: {edge_balance}")
|
||||
|
||||
# Try to subtract 1 - should stay at POSTGRES_INT_MIN
|
||||
edge_result, _ = await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=-1,
|
||||
transaction_type=CreditTransactionType.REFUND,
|
||||
fail_insufficient_credits=False,
|
||||
)
|
||||
print(f"After subtracting 1: {edge_result}")
|
||||
|
||||
assert (
|
||||
edge_result == POSTGRES_INT_MIN
|
||||
), f"Expected balance to remain clamped at {POSTGRES_INT_MIN}, got {edge_result}"
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_underflow_protection_large_refunds(server: SpinTestServer):
|
||||
"""Test that large cumulative refunds don't cause integer underflow."""
|
||||
credit_system = UserCredit()
|
||||
user_id = f"underflow-test-{uuid4()}"
|
||||
await create_test_user(user_id)
|
||||
|
||||
try:
|
||||
# Set up balance close to underflow threshold to test the protection
|
||||
# Set balance to POSTGRES_INT_MIN + 1000, then try to subtract 2000
|
||||
# This should trigger underflow protection
|
||||
from prisma.models import UserBalance
|
||||
|
||||
test_balance = POSTGRES_INT_MIN + 1000
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": test_balance},
|
||||
"update": {"balance": test_balance},
|
||||
},
|
||||
)
|
||||
|
||||
current_balance = await credit_system.get_credits(user_id)
|
||||
assert current_balance == test_balance
|
||||
|
||||
# Try to deduct amount that would cause underflow: test_balance + (-2000) = POSTGRES_INT_MIN - 1000
|
||||
underflow_amount = -2000
|
||||
expected_without_protection = (
|
||||
current_balance + underflow_amount
|
||||
) # Should be POSTGRES_INT_MIN - 1000
|
||||
|
||||
# Use _add_transaction directly with amount that would cause underflow
|
||||
final_balance, _ = await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=underflow_amount,
|
||||
transaction_type=CreditTransactionType.REFUND,
|
||||
fail_insufficient_credits=False, # Allow going negative for refunds
|
||||
)
|
||||
|
||||
# Balance should be clamped to POSTGRES_INT_MIN, not the calculated underflow value
|
||||
assert (
|
||||
final_balance == POSTGRES_INT_MIN
|
||||
), f"Balance should be clamped to {POSTGRES_INT_MIN}, got {final_balance}"
|
||||
assert (
|
||||
final_balance > expected_without_protection
|
||||
), f"Balance should be greater than underflow result {expected_without_protection}, got {final_balance}"
|
||||
|
||||
# Verify with get_credits too
|
||||
stored_balance = await credit_system.get_credits(user_id)
|
||||
assert (
|
||||
stored_balance == POSTGRES_INT_MIN
|
||||
), f"Stored balance should be {POSTGRES_INT_MIN}, got {stored_balance}"
|
||||
|
||||
# Verify transaction was created with the underflow-protected balance
|
||||
transactions = await CreditTransaction.prisma().find_many(
|
||||
where={"userId": user_id, "type": CreditTransactionType.REFUND},
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
assert len(transactions) > 0, "Refund transaction should be created"
|
||||
assert (
|
||||
transactions[0].runningBalance == POSTGRES_INT_MIN
|
||||
), f"Transaction should show clamped balance {POSTGRES_INT_MIN}, got {transactions[0].runningBalance}"
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServer):
|
||||
"""Test that multiple large refunds applied sequentially don't cause underflow."""
|
||||
credit_system = UserCredit()
|
||||
user_id = f"cumulative-underflow-test-{uuid4()}"
|
||||
await create_test_user(user_id)
|
||||
|
||||
try:
|
||||
# Set up balance close to underflow threshold
|
||||
from prisma.models import UserBalance
|
||||
|
||||
initial_balance = POSTGRES_INT_MIN + 500 # Close to minimum but with some room
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": initial_balance},
|
||||
"update": {"balance": initial_balance},
|
||||
},
|
||||
)
|
||||
|
||||
# Apply multiple refunds that would cumulatively underflow
|
||||
refund_amount = -300 # Each refund that would cause underflow when cumulative
|
||||
|
||||
# First refund: (POSTGRES_INT_MIN + 500) + (-300) = POSTGRES_INT_MIN + 200 (still above minimum)
|
||||
balance_1, _ = await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=refund_amount,
|
||||
transaction_type=CreditTransactionType.REFUND,
|
||||
fail_insufficient_credits=False,
|
||||
)
|
||||
|
||||
# Should be above minimum for first refund
|
||||
expected_balance_1 = (
|
||||
initial_balance + refund_amount
|
||||
) # Should be POSTGRES_INT_MIN + 200
|
||||
assert (
|
||||
balance_1 == expected_balance_1
|
||||
), f"First refund should result in {expected_balance_1}, got {balance_1}"
|
||||
assert (
|
||||
balance_1 >= POSTGRES_INT_MIN
|
||||
), f"First refund should not go below {POSTGRES_INT_MIN}, got {balance_1}"
|
||||
|
||||
# Second refund: (POSTGRES_INT_MIN + 200) + (-300) = POSTGRES_INT_MIN - 100 (would underflow)
|
||||
balance_2, _ = await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=refund_amount,
|
||||
transaction_type=CreditTransactionType.REFUND,
|
||||
fail_insufficient_credits=False,
|
||||
)
|
||||
|
||||
# Should be clamped to minimum due to underflow protection
|
||||
assert (
|
||||
balance_2 == POSTGRES_INT_MIN
|
||||
), f"Second refund should be clamped to {POSTGRES_INT_MIN}, got {balance_2}"
|
||||
|
||||
# Third refund: Should stay at minimum
|
||||
balance_3, _ = await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=refund_amount,
|
||||
transaction_type=CreditTransactionType.REFUND,
|
||||
fail_insufficient_credits=False,
|
||||
)
|
||||
|
||||
# Should still be at minimum
|
||||
assert (
|
||||
balance_3 == POSTGRES_INT_MIN
|
||||
), f"Third refund should stay at {POSTGRES_INT_MIN}, got {balance_3}"
|
||||
|
||||
# Final balance check
|
||||
final_balance = await credit_system.get_credits(user_id)
|
||||
assert (
|
||||
final_balance == POSTGRES_INT_MIN
|
||||
), f"Final balance should be {POSTGRES_INT_MIN}, got {final_balance}"
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):
|
||||
"""Test that concurrent large refunds don't cause race condition underflow."""
|
||||
credit_system = UserCredit()
|
||||
user_id = f"concurrent-underflow-test-{uuid4()}"
|
||||
await create_test_user(user_id)
|
||||
|
||||
try:
|
||||
# Set up balance close to underflow threshold
|
||||
from prisma.models import UserBalance
|
||||
|
||||
initial_balance = POSTGRES_INT_MIN + 1000 # Close to minimum
|
||||
await UserBalance.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id, "balance": initial_balance},
|
||||
"update": {"balance": initial_balance},
|
||||
},
|
||||
)
|
||||
|
||||
async def large_refund(amount: int, label: str):
|
||||
try:
|
||||
return await credit_system._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=-amount,
|
||||
transaction_type=CreditTransactionType.REFUND,
|
||||
fail_insufficient_credits=False,
|
||||
)
|
||||
except Exception as e:
|
||||
return f"FAILED-{label}: {e}"
|
||||
|
||||
# Run concurrent refunds that would cause underflow if not protected
|
||||
# Each refund of 500 would cause underflow: initial_balance + (-500) could go below POSTGRES_INT_MIN
|
||||
refund_amount = 500
|
||||
results = await asyncio.gather(
|
||||
large_refund(refund_amount, "A"),
|
||||
large_refund(refund_amount, "B"),
|
||||
large_refund(refund_amount, "C"),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
# Check all results are valid and no underflow occurred
|
||||
valid_results = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, tuple):
|
||||
balance, _ = result
|
||||
assert (
|
||||
balance >= POSTGRES_INT_MIN
|
||||
), f"Result {i} balance {balance} underflowed below {POSTGRES_INT_MIN}"
|
||||
valid_results.append(balance)
|
||||
elif isinstance(result, str) and "FAILED" in result:
|
||||
# Some operations might fail due to validation, that's okay
|
||||
pass
|
||||
else:
|
||||
# Unexpected exception
|
||||
assert not isinstance(
|
||||
result, Exception
|
||||
), f"Unexpected exception in result {i}: {result}"
|
||||
|
||||
# At least one operation should succeed
|
||||
assert (
|
||||
len(valid_results) > 0
|
||||
), f"At least one refund should succeed, got results: {results}"
|
||||
|
||||
# All successful results should be >= POSTGRES_INT_MIN
|
||||
for balance in valid_results:
|
||||
assert (
|
||||
balance >= POSTGRES_INT_MIN
|
||||
), f"Balance {balance} should not be below {POSTGRES_INT_MIN}"
|
||||
|
||||
# Final balance should be valid and at or above POSTGRES_INT_MIN
|
||||
final_balance = await credit_system.get_credits(user_id)
|
||||
assert (
|
||||
final_balance >= POSTGRES_INT_MIN
|
||||
), f"Final balance {final_balance} should not underflow below {POSTGRES_INT_MIN}"
|
||||
|
||||
finally:
|
||||
await cleanup_test_user(user_id)
|
||||
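The underflow tests above expect negative balances to clamp at POSTGRES_INT_MIN rather than wrap around. One way to get that behavior is to clamp inside the UPDATE itself; the sketch below is an assumption about the approach, not the repository's actual query, and the constant is the assumed 32-bit signed minimum.

```python
POSTGRES_INT_MIN = -2_147_483_648  # assumed 32-bit signed minimum

# Illustrative only — GREATEST keeps the stored value from dropping below the
# INT4 minimum, so a large negative refund cannot wrap to a huge positive number.
CLAMPED_ADD_SQL = """
UPDATE platform."UserBalance"
SET "balance" = GREATEST($3, "balance" + $2)
WHERE "userId" = $1
RETURNING "balance"
"""


async def add_with_underflow_clamp(db, user_id: str, amount: int) -> int:
    # `db.query_first` stands in for whatever raw-query helper the codebase uses.
    row = await db.query_first(CLAMPED_ADD_SQL, user_id, amount, POSTGRES_INT_MIN)
    return row["balance"]
```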
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Integration test to verify complete migration from User.balance to UserBalance table.
|
||||
|
||||
This test ensures that:
|
||||
1. No User.balance queries exist in the system
|
||||
2. All balance operations go through UserBalance table
|
||||
3. User and UserBalance stay synchronized properly
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance

from backend.data.credit import UsageTransactionMetadata, UserCredit
from backend.util.json import SafeJson
from backend.util.test import SpinTestServer


async def create_test_user(user_id: str) -> None:
    """Create a test user for migration tests."""
    try:
        await User.prisma().create(
            data={
                "id": user_id,
                "email": f"test-{user_id}@example.com",
                "name": f"Test User {user_id[:8]}",
            }
        )
    except UniqueViolationError:
        # User already exists, continue
        pass


async def cleanup_test_user(user_id: str) -> None:
    """Clean up test user and their data."""
    try:
        await CreditTransaction.prisma().delete_many(where={"userId": user_id})
        await UserBalance.prisma().delete_many(where={"userId": user_id})
        await User.prisma().delete_many(where={"id": user_id})
    except Exception as e:
        # Log cleanup failures but don't fail the test
        print(f"Warning: Failed to cleanup test user {user_id}: {e}")


@pytest.mark.asyncio(loop_scope="session")
async def test_user_balance_migration_complete(server: SpinTestServer):
    """Test that User table balance is never used and UserBalance is source of truth."""
    credit_system = UserCredit()
    user_id = f"migration-test-{datetime.now().timestamp()}"
    await create_test_user(user_id)

    try:
        # 1. Verify User table does NOT have balance set initially
        user = await User.prisma().find_unique(where={"id": user_id})
        assert user is not None
        # User.balance should not exist or should be None/0 if it exists
        user_balance_attr = getattr(user, "balance", None)
        if user_balance_attr is not None:
            assert (
                user_balance_attr == 0 or user_balance_attr is None
            ), f"User.balance should be 0 or None, got {user_balance_attr}"

        # 2. Perform various credit operations using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=1000,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "migration_test"}),
        )
        balance1 = await credit_system.get_credits(user_id)
        assert balance1 == 1000

        await credit_system.spend_credits(
            user_id,
            300,
            UsageTransactionMetadata(
                graph_exec_id="test", reason="Migration test spend"
            ),
        )
        balance2 = await credit_system.get_credits(user_id)
        assert balance2 == 700

        # 3. Verify UserBalance table has correct values
        user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
        assert user_balance is not None
        assert (
            user_balance.balance == 700
        ), f"UserBalance should be 700, got {user_balance.balance}"

        # 4. CRITICAL: Verify User.balance is NEVER updated during operations
        user_after = await User.prisma().find_unique(where={"id": user_id})
        assert user_after is not None
        user_balance_after = getattr(user_after, "balance", None)
        if user_balance_after is not None:
            # If User.balance exists, it should still be 0 (never updated)
            assert (
                user_balance_after == 0 or user_balance_after is None
            ), f"User.balance should remain 0/None after operations, got {user_balance_after}. This indicates User.balance is still being used!"

        # 5. Verify get_credits always returns UserBalance value, not User.balance
        final_balance = await credit_system.get_credits(user_id)
        assert (
            final_balance == user_balance.balance
        ), f"get_credits should return UserBalance value {user_balance.balance}, got {final_balance}"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_detect_stale_user_balance_queries(server: SpinTestServer):
    """Test to detect if any operations are still using User.balance instead of UserBalance."""
    credit_system = UserCredit()
    user_id = f"stale-query-test-{datetime.now().timestamp()}"
    await create_test_user(user_id)

    try:
        # Create UserBalance with specific value
        await UserBalance.prisma().create(
            data={"userId": user_id, "balance": 5000}  # $50
        )

        # Verify that get_credits returns UserBalance value (5000), not any stale User.balance value
        balance = await credit_system.get_credits(user_id)
        assert (
            balance == 5000
        ), f"Expected get_credits to return 5000 from UserBalance, got {balance}"

        # Verify all operations use UserBalance using internal method (bypasses Stripe)
        await credit_system._add_transaction(
            user_id=user_id,
            amount=1000,
            transaction_type=CreditTransactionType.TOP_UP,
            metadata=SafeJson({"test": "final_verification"}),
        )
        final_balance = await credit_system.get_credits(user_id)
        assert final_balance == 6000, f"Expected 6000, got {final_balance}"

        # Verify UserBalance table has the correct value
        user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
        assert user_balance is not None
        assert (
            user_balance.balance == 6000
        ), f"UserBalance should be 6000, got {user_balance.balance}"

    finally:
        await cleanup_test_user(user_id)


@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer):
    """Test that concurrent operations all use UserBalance locking, not User.balance."""
    credit_system = UserCredit()
    user_id = f"concurrent-userbalance-test-{datetime.now().timestamp()}"
    await create_test_user(user_id)

    try:
        # Set initial balance in UserBalance
        await UserBalance.prisma().create(data={"userId": user_id, "balance": 1000})

        # Run concurrent operations to ensure they all use UserBalance atomic operations
        async def concurrent_spend(amount: int, label: str):
            try:
                await credit_system.spend_credits(
                    user_id,
                    amount,
                    UsageTransactionMetadata(
                        graph_exec_id=f"concurrent-{label}",
                        reason=f"Concurrent test {label}",
                    ),
                )
                return f"{label}-SUCCESS"
            except Exception as e:
                return f"{label}-FAILED: {e}"

        # Run concurrent operations
        results = await asyncio.gather(
            concurrent_spend(100, "A"),
            concurrent_spend(200, "B"),
            concurrent_spend(300, "C"),
            return_exceptions=True,
        )

        # All should succeed (1000 >= 100+200+300)
        successful = [r for r in results if "SUCCESS" in str(r)]
        assert len(successful) == 3, f"All operations should succeed, got {results}"

        # Final balance should be 1000 - 600 = 400
        final_balance = await credit_system.get_credits(user_id)
        assert final_balance == 400, f"Expected final balance 400, got {final_balance}"

        # Verify UserBalance has correct value
        user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
        assert user_balance is not None
        assert (
            user_balance.balance == 400
        ), f"UserBalance should be 400, got {user_balance.balance}"

        # Critical: If User.balance exists and was used, it might have wrong value
        try:
            user = await User.prisma().find_unique(where={"id": user_id})
            user_balance_attr = getattr(user, "balance", None)
            if user_balance_attr is not None:
                # If User.balance exists, it should NOT be used for operations
                # The fact that our final balance is correct from UserBalance proves the system is working
                print(
                    f"✅ User.balance exists ({user_balance_attr}) but UserBalance ({user_balance.balance}) is being used correctly"
                )
        except Exception:
            print("✅ User.balance column doesn't exist - migration is complete")

    finally:
        await cleanup_test_user(user_id)

@@ -98,42 +98,6 @@ async def transaction(timeout: int = TRANSACTION_TIMEOUT):
        yield tx


@asynccontextmanager
async def locked_transaction(key: str, timeout: int = TRANSACTION_TIMEOUT):
    """
    Create a transaction and take a per-key advisory *transaction* lock.

    - Uses a 64-bit lock id via hashtextextended(key, 0) to avoid 32-bit collisions.
    - Bound by lock_timeout and statement_timeout so it won't block indefinitely.
    - Lock is held for the duration of the transaction and auto-released on commit/rollback.

    Args:
        key: String lock key (e.g., "usr_trx_<uuid>").
        timeout: Transaction/lock/statement timeout in milliseconds.
    """
    async with transaction(timeout=timeout) as tx:
        # Ensure we don't wait longer than desired
        # Note: SET LOCAL doesn't support parameterized queries, must use string interpolation
        await tx.execute_raw(f"SET LOCAL statement_timeout = '{int(timeout)}ms'")  # type: ignore[arg-type]
        await tx.execute_raw(f"SET LOCAL lock_timeout = '{int(timeout)}ms'")  # type: ignore[arg-type]

        # Block until acquired or lock_timeout hits
        try:
            await tx.execute_raw(
                "SELECT pg_advisory_xact_lock(hashtextextended($1, 0))",
                key,
            )
        except Exception as e:
            # Normalize PG's lock timeout error to TimeoutError for callers
            if "lock timeout" in str(e).lower():
                raise TimeoutError(
                    f"Could not acquire lock for key={key!r} within {timeout}ms"
                ) from e
            raise

        yield tx


def get_database_schema() -> str:
    """Extract database schema from DATABASE_URL."""
    parsed_url = urlparse(DATABASE_URL)

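A minimal usage sketch of the `locked_transaction` helper shown above, assuming it is importable from `backend.data.db`; the caller function, the key format, and the UPDATE statement are illustrative and not part of this diff:

```python
from backend.data.db import locked_transaction  # assumed import path


async def add_credits_atomically(user_id: str, amount: int) -> None:
    """Illustrative caller: serialize balance updates for one user across workers."""
    # Every caller that uses the same key blocks on the same advisory lock,
    # so this read-modify-write cannot interleave with another transaction.
    async with locked_transaction(f"usr_trx_{user_id}") as tx:
        await tx.execute_raw(
            'UPDATE "UserBalance" SET balance = balance + $1 WHERE "userId" = $2',
            amount,
            user_id,
        )
    # pg_advisory_xact_lock is released automatically on commit or rollback.
```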
@@ -347,6 +347,9 @@ class APIKeyCredentials(_BaseCredentials):
    """Unix timestamp (seconds) indicating when the API key expires (if at all)"""

    def auth_header(self) -> str:
        # Linear API keys should not have Bearer prefix
        if self.provider == "linear":
            return self.api_key.get_secret_value()
        return f"Bearer {self.api_key.get_secret_value()}"


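The provider check above is easy to restate in isolation; a small hedged sketch (the helper name and sample keys below are made up):

```python
from pydantic import SecretStr


def build_auth_header(provider: str, api_key: SecretStr) -> str:
    """Mirror of the rule above: Linear gets the raw key, every other provider gets Bearer."""
    if provider == "linear":
        return api_key.get_secret_value()
    return f"Bearer {api_key.get_secret_value()}"


assert build_auth_header("linear", SecretStr("lin_api_123")) == "lin_api_123"
assert build_auth_header("github", SecretStr("ghp_abc")) == "Bearer ghp_abc"
```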
@@ -4,7 +4,6 @@ from typing import Any, Optional

import prisma
import pydantic
from autogpt_libs.utils.cache import cached
from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
@@ -13,6 +12,7 @@ from backend.data.block import get_blocks
from backend.data.credit import get_user_credit_model
from backend.data.model import CredentialsMetaInput
from backend.server.v2.store.model import StoreAgentDetails
from backend.util.cache import cached
from backend.util.json import SafeJson

# Mapping from user reason id to categories to search for when choosing agent to show
@@ -26,8 +26,6 @@ REASON_MAPPING: dict[str, list[str]] = {
POINTS_AGENT_COUNT = 50  # Number of agents to calculate points for
MIN_AGENT_COUNT = 2  # Minimum number of marketplace agents to enable onboarding

user_credit = get_user_credit_model()


class UserOnboardingUpdate(pydantic.BaseModel):
    completedSteps: Optional[list[OnboardingStep]] = None
@@ -147,7 +145,8 @@ async def reward_user(user_id: str, step: OnboardingStep):
        return

    onboarding.rewardedFor.append(step)
    await user_credit.onboarding_reward(user_id, reward, step)
    user_credit_model = await get_user_credit_model(user_id)
    await user_credit_model.onboarding_reward(user_id, reward, step)
    await UserOnboarding.prisma().update(
        where={"userId": user_id},
        data={

autogpt_platform/backend/backend/data/partial_types.py (new file, 5 lines)
@@ -0,0 +1,5 @@
import prisma.models


class StoreAgentWithRank(prisma.models.StoreAgent):
    rank: float
@@ -1,11 +1,11 @@
import logging
import os

from autogpt_libs.utils.cache import cached, thread_cached
from dotenv import load_dotenv
from redis import Redis
from redis.asyncio import Redis as AsyncRedis

from backend.util.cache import cached, thread_cached
from backend.util.retry import conn_retry

load_dotenv()
@@ -34,7 +34,7 @@ def disconnect():
    get_redis().close()


@cached()
@cached(ttl_seconds=3600)
def get_redis() -> Redis:
    return connect()


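The same decorator swap from `autogpt_libs.utils.cache` to `backend.util.cache` recurs throughout this diff; a hedged sketch of how the new decorator is applied (the decorated function below is made up, and only the parameters visible elsewhere in this diff — `maxsize`, `ttl_seconds`, `shared_cache` — are assumed to exist):

```python
from backend.util.cache import cached


# Illustrative only: memoize an expensive lookup for an hour, keyed by its arguments.
@cached(maxsize=100, ttl_seconds=3600)
async def get_marketplace_agent_count() -> int:
    ...  # e.g. a database count that is safe to serve slightly stale
```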
@@ -7,7 +7,6 @@ from typing import Optional, cast
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from autogpt_libs.auth.models import DEFAULT_USER_ID
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from fastapi import HTTPException
|
||||
from prisma.enums import NotificationType
|
||||
from prisma.models import User as PrismaUser
|
||||
@@ -16,6 +15,7 @@ from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
|
||||
from backend.data.db import prisma
|
||||
from backend.data.model import User, UserIntegrations, UserMetadata
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.util.cache import cached
|
||||
from backend.util.encryption import JSONCryptor
|
||||
from backend.util.exceptions import DatabaseError
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from typing import Callable, Concatenate, ParamSpec, TypeVar, cast
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cast
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
@@ -39,6 +40,7 @@ from backend.data.notifications import (
|
||||
)
|
||||
from backend.data.user import (
|
||||
get_active_user_ids_in_timerange,
|
||||
get_user_by_id,
|
||||
get_user_email_by_id,
|
||||
get_user_email_verification,
|
||||
get_user_integrations,
|
||||
@@ -56,8 +58,10 @@ from backend.util.service import (
|
||||
)
|
||||
from backend.util.settings import Config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import FastAPI
|
||||
|
||||
config = Config()
|
||||
_user_credit_model = get_user_credit_model()
|
||||
logger = logging.getLogger(__name__)
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
@@ -66,23 +70,27 @@ R = TypeVar("R")
|
||||
async def _spend_credits(
|
||||
user_id: str, cost: int, metadata: UsageTransactionMetadata
|
||||
) -> int:
|
||||
return await _user_credit_model.spend_credits(user_id, cost, metadata)
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
return await user_credit_model.spend_credits(user_id, cost, metadata)
|
||||
|
||||
|
||||
async def _get_credits(user_id: str) -> int:
|
||||
return await _user_credit_model.get_credits(user_id)
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
return await user_credit_model.get_credits(user_id)
|
||||
|
||||
|
||||
class DatabaseManager(AppService):
|
||||
def run_service(self) -> None:
|
||||
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
|
||||
self.run_and_wait(db.connect())
|
||||
super().run_service()
|
||||
@asynccontextmanager
|
||||
async def lifespan(self, app: "FastAPI"):
|
||||
async with super().lifespan(app):
|
||||
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
|
||||
await db.connect()
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
|
||||
self.run_and_wait(db.disconnect())
|
||||
logger.info(f"[{self.service_name}] ✅ Ready")
|
||||
yield
|
||||
|
||||
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
|
||||
await db.disconnect()
|
||||
|
||||
async def health_check(self) -> str:
|
||||
if not db.is_connected():
|
||||
@@ -145,6 +153,7 @@ class DatabaseManager(AppService):
|
||||
|
||||
# User Comms - async
|
||||
get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
|
||||
get_user_by_id = _(get_user_by_id)
|
||||
get_user_email_by_id = _(get_user_email_by_id)
|
||||
get_user_email_verification = _(get_user_email_verification)
|
||||
get_user_notification_preference = _(get_user_notification_preference)
|
||||
@@ -230,6 +239,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
get_node = d.get_node
|
||||
get_node_execution = d.get_node_execution
|
||||
get_node_executions = d.get_node_executions
|
||||
get_user_by_id = d.get_user_by_id
|
||||
get_user_integrations = d.get_user_integrations
|
||||
upsert_execution_input = d.upsert_execution_input
|
||||
upsert_execution_output = d.upsert_execution_output
|
||||
|
||||
@@ -246,7 +246,7 @@ async def execute_node(
|
||||
async for output_name, output_data in node_block.execute(
|
||||
input_data, **extra_exec_kwargs
|
||||
):
|
||||
output_data = json.convert_pydantic_to_json(output_data)
|
||||
output_data = json.to_dict(output_data)
|
||||
output_size += len(json.dumps(output_data))
|
||||
log_metadata.debug("Node produced output", **{output_name: output_data})
|
||||
yield output_name, output_data
|
||||
@@ -1548,11 +1548,12 @@ class ExecutionManager(AppProcess):
|
||||
logger.warning(
|
||||
f"[{self.service_name}] Graph {graph_exec_id} already running on pod {current_owner}"
|
||||
)
|
||||
_ack_message(reject=True, requeue=False)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[{self.service_name}] Could not acquire lock for {graph_exec_id} - Redis unavailable"
|
||||
)
|
||||
_ack_message(reject=True, requeue=True)
|
||||
_ack_message(reject=True, requeue=True)
|
||||
return
|
||||
self._execution_locks[graph_exec_id] = cluster_lock
|
||||
|
||||
@@ -1713,6 +1714,8 @@ class ExecutionManager(AppProcess):
|
||||
|
||||
logger.info(f"{prefix} ✅ Finished GraphExec cleanup")
|
||||
|
||||
super().cleanup()
|
||||
|
||||
|
||||
# ------- UTILITIES ------- #
|
||||
|
||||
|
||||
@@ -248,7 +248,7 @@ class Scheduler(AppService):
|
||||
raise UnhealthyServiceError("Scheduler is still initializing")
|
||||
|
||||
# Check if we're in the middle of cleanup
|
||||
if self.cleaned_up:
|
||||
if self._shutting_down:
|
||||
return await super().health_check()
|
||||
|
||||
# Normal operation - check if scheduler is running
|
||||
@@ -375,7 +375,6 @@ class Scheduler(AppService):
|
||||
super().run_service()
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
if self.scheduler:
|
||||
logger.info("⏳ Shutting down scheduler...")
|
||||
self.scheduler.shutdown(wait=True)
|
||||
@@ -390,7 +389,7 @@ class Scheduler(AppService):
|
||||
logger.info("⏳ Waiting for event loop thread to finish...")
|
||||
_event_loop_thread.join(timeout=SCHEDULER_OPERATION_TIMEOUT_SECONDS)
|
||||
|
||||
logger.info("Scheduler cleanup complete.")
|
||||
super().cleanup()
|
||||
|
||||
@expose
|
||||
def add_graph_execution_schedule(
|
||||
|
||||
@@ -34,6 +34,7 @@ from backend.data.graph import GraphModel, Node
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.util.cache import cached
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
get_async_execution_queue,
|
||||
@@ -41,11 +42,12 @@ from backend.util.clients import (
|
||||
get_integration_credentials_store,
|
||||
)
|
||||
from backend.util.exceptions import GraphValidationError, NotFoundError
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
|
||||
from backend.util.settings import Config
|
||||
from backend.util.type import convert
|
||||
|
||||
|
||||
@cached(maxsize=1000, ttl_seconds=3600)
|
||||
async def get_user_context(user_id: str) -> UserContext:
|
||||
"""
|
||||
Get UserContext for a user, always returns a valid context with timezone.
|
||||
@@ -53,7 +55,11 @@ async def get_user_context(user_id: str) -> UserContext:
|
||||
"""
|
||||
user_context = UserContext(timezone="UTC") # Default to UTC
|
||||
try:
|
||||
user = await get_user_by_id(user_id)
|
||||
if prisma.is_connected():
|
||||
user = await get_user_by_id(user_id)
|
||||
else:
|
||||
user = await get_database_manager_async_client().get_user_by_id(user_id)
|
||||
|
||||
if user and user.timezone and user.timezone != "not-set":
|
||||
user_context.timezone = user.timezone
|
||||
logger.debug(f"Retrieved user context: timezone={user.timezone}")
|
||||
@@ -93,7 +99,11 @@ class LogMetadata(TruncatedLogger):
|
||||
"node_id": node_id,
|
||||
"block_name": block_name,
|
||||
}
|
||||
prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|neid:{node_eid}|{block_name}]"
|
||||
prefix = (
|
||||
"[ExecutionManager]"
|
||||
if is_structured_logging_enabled()
|
||||
else f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|neid:{node_eid}|{block_name}]" # noqa
|
||||
)
|
||||
super().__init__(
|
||||
logger,
|
||||
max_length=max_length,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from backend.util.cache import cached
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..providers import ProviderName
|
||||
@@ -8,7 +8,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
# --8<-- [start:load_webhook_managers]
|
||||
@cached()
|
||||
@cached(ttl_seconds=3600)
|
||||
def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]:
|
||||
webhook_managers = {}
|
||||
|
||||
|
||||
@@ -1017,10 +1017,14 @@ class NotificationManager(AppService):
|
||||
logger.exception(f"Fatal error in consumer for {queue_name}: {e}")
|
||||
raise
|
||||
|
||||
@continuous_retry()
|
||||
def run_service(self):
|
||||
self.run_and_wait(self._run_service())
|
||||
# Queue the main _run_service task
|
||||
asyncio.run_coroutine_threadsafe(self._run_service(), self.shared_event_loop)
|
||||
|
||||
# Start the main event loop
|
||||
super().run_service()
|
||||
|
||||
@continuous_retry()
|
||||
async def _run_service(self):
|
||||
logger.info(f"[{self.service_name}] ⏳ Configuring RabbitMQ...")
|
||||
self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.rabbitmq_config)
|
||||
@@ -1086,10 +1090,11 @@ class NotificationManager(AppService):
|
||||
def cleanup(self):
|
||||
"""Cleanup service resources"""
|
||||
self.running = False
|
||||
super().cleanup()
|
||||
logger.info(f"[{self.service_name}] ⏳ Disconnecting RabbitMQ...")
|
||||
logger.info("⏳ Disconnecting RabbitMQ...")
|
||||
self.run_and_wait(self.rabbitmq_service.disconnect())
|
||||
|
||||
super().cleanup()
|
||||
|
||||
|
||||
class NotificationManagerClient(AppServiceClient):
|
||||
@classmethod
|
||||
|
||||
@@ -14,19 +14,49 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
|
||||
@pytest.fixture
|
||||
def test_user_id() -> str:
|
||||
"""Test user ID fixture."""
|
||||
return "test-user-id"
|
||||
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user_id() -> str:
|
||||
"""Admin user ID fixture."""
|
||||
return "admin-user-id"
|
||||
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def target_user_id() -> str:
|
||||
"""Target user ID fixture."""
|
||||
return "target-user-id"
|
||||
return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_test_user(test_user_id):
|
||||
"""Create test user in database before tests."""
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Create the test user in the database using JWT token format
|
||||
user_data = {
|
||||
"sub": test_user_id,
|
||||
"email": "test@example.com",
|
||||
"user_metadata": {"name": "Test User"},
|
||||
}
|
||||
await get_or_create_user(user_data)
|
||||
return test_user_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_admin_user(admin_user_id):
|
||||
"""Create admin user in database before tests."""
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Create the admin user in the database using JWT token format
|
||||
user_data = {
|
||||
"sub": admin_user_id,
|
||||
"email": "test-admin@example.com",
|
||||
"user_metadata": {"name": "Test Admin"},
|
||||
}
|
||||
await get_or_create_user(user_data)
|
||||
return admin_user_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -321,10 +321,6 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
|
||||
uvicorn.run(**uvicorn_config)
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
logger.info(f"[{self.service_name}] ⏳ Shutting down Agent Server...")
|
||||
|
||||
@staticmethod
|
||||
async def test_execute_graph(
|
||||
graph_id: str,
|
||||
|
||||
@@ -11,7 +11,6 @@ import pydantic
|
||||
import stripe
|
||||
from autogpt_libs.auth import get_user_id, requires_user
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
@@ -40,6 +39,7 @@ from backend.data.credit import (
|
||||
AutoTopUpConfig,
|
||||
RefundRequest,
|
||||
TransactionHistory,
|
||||
UserCredit,
|
||||
get_auto_top_up,
|
||||
get_user_credit_model,
|
||||
set_auto_top_up,
|
||||
@@ -84,6 +84,7 @@ from backend.server.model import (
|
||||
UpdateTimezoneRequest,
|
||||
UploadFileResponse,
|
||||
)
|
||||
from backend.util.cache import cached
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.cloud_storage import get_cloud_storage_handler
|
||||
from backend.util.exceptions import GraphValidationError, NotFoundError
|
||||
@@ -107,9 +108,6 @@ def _create_file_size_error(size_bytes: int, max_size_mb: int) -> HTTPException:
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_user_credit_model = get_user_credit_model()
|
||||
|
||||
# Define the API routes
|
||||
v1_router = APIRouter()
|
||||
|
||||
@@ -291,7 +289,7 @@ def _compute_blocks_sync() -> str:
|
||||
return dumps(result)
|
||||
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=3600)
|
||||
async def _get_cached_blocks() -> str:
|
||||
"""
|
||||
Async cached function with thundering herd protection.
|
||||
@@ -478,7 +476,8 @@ async def upload_file(
|
||||
async def get_user_credits(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> dict[str, int]:
|
||||
return {"credits": await _user_credit_model.get_credits(user_id)}
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
return {"credits": await user_credit_model.get_credits(user_id)}
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
@@ -490,9 +489,8 @@ async def get_user_credits(
|
||||
async def request_top_up(
|
||||
request: RequestTopUp, user_id: Annotated[str, Security(get_user_id)]
|
||||
):
|
||||
checkout_url = await _user_credit_model.top_up_intent(
|
||||
user_id, request.credit_amount
|
||||
)
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
checkout_url = await user_credit_model.top_up_intent(user_id, request.credit_amount)
|
||||
return {"checkout_url": checkout_url}
|
||||
|
||||
|
||||
@@ -507,7 +505,8 @@ async def refund_top_up(
|
||||
transaction_key: str,
|
||||
metadata: dict[str, str],
|
||||
) -> int:
|
||||
return await _user_credit_model.top_up_refund(user_id, transaction_key, metadata)
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
return await user_credit_model.top_up_refund(user_id, transaction_key, metadata)
|
||||
|
||||
|
||||
@v1_router.patch(
|
||||
@@ -517,7 +516,8 @@ async def refund_top_up(
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def fulfill_checkout(user_id: Annotated[str, Security(get_user_id)]):
|
||||
await _user_credit_model.fulfill_checkout(user_id=user_id)
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
await user_credit_model.fulfill_checkout(user_id=user_id)
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@@ -531,18 +531,23 @@ async def configure_user_auto_top_up(
|
||||
request: AutoTopUpConfig, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> str:
|
||||
if request.threshold < 0:
|
||||
raise ValueError("Threshold must be greater than 0")
|
||||
raise HTTPException(status_code=422, detail="Threshold must be greater than 0")
|
||||
if request.amount < 500 and request.amount != 0:
|
||||
raise ValueError("Amount must be greater than or equal to 500")
|
||||
if request.amount < request.threshold:
|
||||
raise ValueError("Amount must be greater than or equal to threshold")
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Amount must be greater than or equal to 500"
|
||||
)
|
||||
if request.amount != 0 and request.amount < request.threshold:
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Amount must be greater than or equal to threshold"
|
||||
)
|
||||
|
||||
current_balance = await _user_credit_model.get_credits(user_id)
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
current_balance = await user_credit_model.get_credits(user_id)
|
||||
|
||||
if current_balance < request.threshold:
|
||||
await _user_credit_model.top_up_credits(user_id, request.amount)
|
||||
await user_credit_model.top_up_credits(user_id, request.amount)
|
||||
else:
|
||||
await _user_credit_model.top_up_credits(user_id, 0)
|
||||
await user_credit_model.top_up_credits(user_id, 0)
|
||||
|
||||
await set_auto_top_up(
|
||||
user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
|
||||
@@ -590,15 +595,13 @@ async def stripe_webhook(request: Request):
|
||||
event["type"] == "checkout.session.completed"
|
||||
or event["type"] == "checkout.session.async_payment_succeeded"
|
||||
):
|
||||
await _user_credit_model.fulfill_checkout(
|
||||
session_id=event["data"]["object"]["id"]
|
||||
)
|
||||
await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"])
|
||||
|
||||
if event["type"] == "charge.dispute.created":
|
||||
await _user_credit_model.handle_dispute(event["data"]["object"])
|
||||
await UserCredit().handle_dispute(event["data"]["object"])
|
||||
|
||||
if event["type"] == "refund.created" or event["type"] == "charge.dispute.closed":
|
||||
await _user_credit_model.deduct_credits(event["data"]["object"])
|
||||
await UserCredit().deduct_credits(event["data"]["object"])
|
||||
|
||||
return Response(status_code=200)
|
||||
|
||||
@@ -612,7 +615,8 @@ async def stripe_webhook(request: Request):
|
||||
async def manage_payment_method(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> dict[str, str]:
|
||||
return {"url": await _user_credit_model.create_billing_portal_session(user_id)}
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
return {"url": await user_credit_model.create_billing_portal_session(user_id)}
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@@ -630,7 +634,8 @@ async def get_credit_history(
|
||||
if transaction_count_limit < 1 or transaction_count_limit > 1000:
|
||||
raise ValueError("Transaction count limit must be between 1 and 1000")
|
||||
|
||||
return await _user_credit_model.get_transaction_history(
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
return await user_credit_model.get_transaction_history(
|
||||
user_id=user_id,
|
||||
transaction_time_ceiling=transaction_time,
|
||||
transaction_count_limit=transaction_count_limit,
|
||||
@@ -647,7 +652,8 @@ async def get_credit_history(
|
||||
async def get_refund_requests(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[RefundRequest]:
|
||||
return await _user_credit_model.get_refund_requests(user_id)
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
return await user_credit_model.get_refund_requests(user_id)
|
||||
|
||||
|
||||
########################################################
|
||||
@@ -869,7 +875,8 @@ async def execute_graph(
|
||||
graph_version: Optional[int] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
) -> execution_db.GraphExecutionMeta:
|
||||
current_balance = await _user_credit_model.get_credits(user_id)
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
current_balance = await user_credit_model.get_credits(user_id)
|
||||
if current_balance <= 0:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
|
||||
@@ -23,10 +23,13 @@ client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
def setup_app_auth(mock_jwt_user, setup_test_user):
|
||||
"""Setup auth overrides for all tests in this module"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
# setup_test_user fixture already executed and user is created in database
|
||||
# It returns the user_id which we don't need to await
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
@@ -194,8 +197,12 @@ def test_get_user_credits(
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test get user credits endpoint"""
|
||||
mock_credit_model = mocker.patch("backend.server.routers.v1._user_credit_model")
|
||||
mock_credit_model = Mock()
|
||||
mock_credit_model.get_credits = AsyncMock(return_value=1000)
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.get_user_credit_model",
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
|
||||
response = client.get("/credits")
|
||||
|
||||
@@ -215,10 +222,14 @@ def test_request_top_up(
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test request top up endpoint"""
|
||||
mock_credit_model = mocker.patch("backend.server.routers.v1._user_credit_model")
|
||||
mock_credit_model = Mock()
|
||||
mock_credit_model.top_up_intent = AsyncMock(
|
||||
return_value="https://checkout.example.com/session123"
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.get_user_credit_model",
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
|
||||
request_data = {"credit_amount": 500}
|
||||
|
||||
@@ -261,6 +272,74 @@ def test_get_auto_top_up(
|
||||
)
|
||||
|
||||
|
||||
def test_configure_auto_top_up(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test configure auto top-up endpoint - this test would have caught the enum casting bug"""
|
||||
# Mock the set_auto_top_up function to avoid database operations
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.set_auto_top_up",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
# Mock credit model to avoid Stripe API calls
|
||||
mock_credit_model = mocker.AsyncMock()
|
||||
mock_credit_model.get_credits.return_value = 50 # Current balance below threshold
|
||||
mock_credit_model.top_up_credits.return_value = None
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.get_user_credit_model",
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
|
||||
# Test data
|
||||
request_data = {
|
||||
"threshold": 100,
|
||||
"amount": 500,
|
||||
}
|
||||
|
||||
response = client.post("/credits/auto-top-up", json=request_data)
|
||||
|
||||
# This should succeed with our fix, but would have failed before with the enum casting error
|
||||
assert response.status_code == 200
|
||||
assert response.json() == "Auto top-up settings updated"
|
||||
|
||||
|
||||
def test_configure_auto_top_up_validation_errors(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
"""Test configure auto top-up endpoint validation"""
|
||||
# Mock set_auto_top_up to avoid database operations for successful case
|
||||
mocker.patch("backend.server.routers.v1.set_auto_top_up")
|
||||
|
||||
# Mock credit model to avoid Stripe API calls for the successful case
|
||||
mock_credit_model = mocker.AsyncMock()
|
||||
mock_credit_model.get_credits.return_value = 50
|
||||
mock_credit_model.top_up_credits.return_value = None
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.get_user_credit_model",
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
|
||||
# Test negative threshold
|
||||
response = client.post(
|
||||
"/credits/auto-top-up", json={"threshold": -1, "amount": 500}
|
||||
)
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
# Test amount too small (but not 0)
|
||||
response = client.post(
|
||||
"/credits/auto-top-up", json={"threshold": 100, "amount": 100}
|
||||
)
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
# Test amount = 0 (should be allowed)
|
||||
response = client.post("/credits/auto-top-up", json={"threshold": 100, "amount": 0})
|
||||
assert response.status_code == 200 # Should succeed
|
||||
|
||||
|
||||
# Graphs endpoints tests
|
||||
def test_get_graphs(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
|
||||
@@ -11,8 +11,6 @@ from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_user_credit_model = get_user_credit_model()
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/admin",
|
||||
@@ -33,7 +31,8 @@ async def add_user_credits(
|
||||
logger.info(
|
||||
f"Admin user {admin_user_id} is adding {amount} credits to user {user_id}"
|
||||
)
|
||||
new_balance, transaction_key = await _user_credit_model._add_transaction(
|
||||
user_credit_model = await get_user_credit_model(user_id)
|
||||
new_balance, transaction_key = await user_credit_model._add_transaction(
|
||||
user_id,
|
||||
amount,
|
||||
transaction_type=CreditTransactionType.GRANT,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
@@ -7,12 +7,12 @@ import prisma.enums
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from prisma import Json
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
import backend.server.v2.admin.credit_admin_routes as credit_admin_routes
|
||||
import backend.server.v2.admin.model as admin_model
|
||||
from backend.data.model import UserTransaction
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
@@ -37,12 +37,14 @@ def test_add_user_credits_success(
|
||||
) -> None:
|
||||
"""Test successful credit addition by admin"""
|
||||
# Mock the credit model
|
||||
mock_credit_model = mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes._user_credit_model"
|
||||
)
|
||||
mock_credit_model = Mock()
|
||||
mock_credit_model._add_transaction = AsyncMock(
|
||||
return_value=(1500, "transaction-123-uuid")
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes.get_user_credit_model",
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"user_id": target_user_id,
|
||||
@@ -62,11 +64,17 @@ def test_add_user_credits_success(
|
||||
call_args = mock_credit_model._add_transaction.call_args
|
||||
assert call_args[0] == (target_user_id, 500)
|
||||
assert call_args[1]["transaction_type"] == prisma.enums.CreditTransactionType.GRANT
|
||||
# Check that metadata is a Json object with the expected content
|
||||
assert isinstance(call_args[1]["metadata"], Json)
|
||||
assert call_args[1]["metadata"] == Json(
|
||||
{"admin_id": admin_user_id, "reason": "Test credit grant for debugging"}
|
||||
)
|
||||
# Check that metadata is a SafeJson object with the expected content
|
||||
assert isinstance(call_args[1]["metadata"], SafeJson)
|
||||
actual_metadata = call_args[1]["metadata"]
|
||||
expected_data = {
|
||||
"admin_id": admin_user_id,
|
||||
"reason": "Test credit grant for debugging",
|
||||
}
|
||||
|
||||
# SafeJson inherits from Json which stores parsed data in the .data attribute
|
||||
assert actual_metadata.data["admin_id"] == expected_data["admin_id"]
|
||||
assert actual_metadata.data["reason"] == expected_data["reason"]
|
||||
|
||||
# Snapshot test the response
|
||||
configured_snapshot.assert_match(
|
||||
@@ -81,12 +89,14 @@ def test_add_user_credits_negative_amount(
|
||||
) -> None:
|
||||
"""Test credit deduction by admin (negative amount)"""
|
||||
# Mock the credit model
|
||||
mock_credit_model = mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes._user_credit_model"
|
||||
)
|
||||
mock_credit_model = Mock()
|
||||
mock_credit_model._add_transaction = AsyncMock(
|
||||
return_value=(200, "transaction-456-uuid")
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.v2.admin.credit_admin_routes.get_user_credit_model",
|
||||
return_value=mock_credit_model,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"user_id": "target-user-id",
|
||||
|
||||
@@ -2,7 +2,6 @@ import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import prisma
|
||||
from autogpt_libs.utils.cache import cached
|
||||
|
||||
import backend.data.block
|
||||
from backend.blocks import load_all_blocks
|
||||
@@ -18,6 +17,7 @@ from backend.server.v2.builder.model import (
|
||||
ProviderResponse,
|
||||
SearchBlocksResponse,
|
||||
)
|
||||
from backend.util.cache import cached
|
||||
from backend.util.models import Pagination
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -307,7 +307,7 @@ def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=3600)
|
||||
def _get_all_providers() -> dict[ProviderName, Provider]:
|
||||
providers: dict[ProviderName, Provider] = {}
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from autogpt_libs.utils.cache import cached
|
||||
|
||||
import backend.server.v2.store.db
|
||||
from backend.util.cache import cached
|
||||
|
||||
##############################################
|
||||
############### Caches #######################
|
||||
@@ -17,7 +16,7 @@ def clear_all_caches():
|
||||
|
||||
# Cache store agents list for 5 minutes
|
||||
# Different cache entries for different query combinations
|
||||
@cached(maxsize=5000, ttl_seconds=300)
|
||||
@cached(maxsize=5000, ttl_seconds=300, shared_cache=True)
|
||||
async def _get_cached_store_agents(
|
||||
featured: bool,
|
||||
creator: str | None,
|
||||
@@ -40,7 +39,7 @@ async def _get_cached_store_agents(
|
||||
|
||||
|
||||
# Cache individual agent details for 15 minutes
|
||||
@cached(maxsize=200, ttl_seconds=300)
|
||||
@cached(maxsize=200, ttl_seconds=300, shared_cache=True)
|
||||
async def _get_cached_agent_details(username: str, agent_name: str):
|
||||
"""Cached helper to get agent details."""
|
||||
return await backend.server.v2.store.db.get_store_agent_details(
|
||||
@@ -49,7 +48,7 @@ async def _get_cached_agent_details(username: str, agent_name: str):
|
||||
|
||||
|
||||
# Cache creators list for 5 minutes
|
||||
@cached(maxsize=200, ttl_seconds=300)
|
||||
@cached(maxsize=200, ttl_seconds=300, shared_cache=True)
|
||||
async def _get_cached_store_creators(
|
||||
featured: bool,
|
||||
search_query: str | None,
|
||||
@@ -68,7 +67,7 @@ async def _get_cached_store_creators(
|
||||
|
||||
|
||||
# Cache individual creator details for 5 minutes
|
||||
@cached(maxsize=100, ttl_seconds=300)
|
||||
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
|
||||
async def _get_cached_creator_details(username: str):
|
||||
"""Cached helper to get creator details."""
|
||||
return await backend.server.v2.store.db.get_store_creator_details(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import fastapi
|
||||
@@ -71,64 +72,199 @@ async def get_store_agents(
|
||||
logger.debug(
|
||||
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
|
||||
)
|
||||
search_term = sanitize_query(search_query)
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
|
||||
sanitized_creators = []
|
||||
if creators:
|
||||
where_clause["creator_username"] = {"in": creators}
|
||||
for c in creators:
|
||||
sanitized_creators.append(sanitize_query(c))
|
||||
|
||||
sanitized_category = None
|
||||
if category:
|
||||
where_clause["categories"] = {"has": category}
|
||||
|
||||
if search_term:
|
||||
where_clause["OR"] = [
|
||||
{"agent_name": {"contains": search_term, "mode": "insensitive"}},
|
||||
{"description": {"contains": search_term, "mode": "insensitive"}},
|
||||
]
|
||||
|
||||
order_by = []
|
||||
if sorted_by == "rating":
|
||||
order_by.append({"rating": "desc"})
|
||||
elif sorted_by == "runs":
|
||||
order_by.append({"runs": "desc"})
|
||||
elif sorted_by == "name":
|
||||
order_by.append({"agent_name": "asc"})
|
||||
sanitized_category = sanitize_query(category)
|
||||
|
||||
try:
|
||||
agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
|
||||
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
store_agents: list[backend.server.v2.store.model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
try:
|
||||
# Create the StoreAgent object safely
|
||||
store_agent = backend.server.v2.store.model.StoreAgent(
|
||||
slug=agent.slug,
|
||||
agent_name=agent.agent_name,
|
||||
agent_image=agent.agent_image[0] if agent.agent_image else "",
|
||||
creator=agent.creator_username or "Needs Profile",
|
||||
creator_avatar=agent.creator_avatar or "",
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
# If search_query is provided, use full-text search
|
||||
if search_query:
|
||||
search_term = sanitize_query(search_query)
|
||||
if not search_term:
|
||||
# Return empty results for invalid search query
|
||||
return backend.server.v2.store.model.StoreAgentsResponse(
|
||||
agents=[],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
current_page=page,
|
||||
total_items=0,
|
||||
total_pages=0,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
# Add to the list only if creation was successful
|
||||
store_agents.append(store_agent)
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
# You could log the error here if needed
|
||||
logger.error(
|
||||
f"Error parsing Store agent when getting store agents from db: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
            offset = (page - 1) * page_size

            # Whitelist allowed order_by columns
            ALLOWED_ORDER_BY = {
                "rating": "rating DESC, rank DESC",
                "runs": "runs DESC, rank DESC",
                "name": "agent_name ASC, rank DESC",
                "updated_at": "updated_at DESC, rank DESC",
            }

            # Validate and get order clause
            if sorted_by and sorted_by in ALLOWED_ORDER_BY:
                order_by_clause = ALLOWED_ORDER_BY[sorted_by]
            else:
                order_by_clause = "updated_at DESC, rank DESC"

            # Build WHERE conditions and parameters list
            where_parts: list[str] = []
            params: list[typing.Any] = [search_term]  # $1 - search term
            param_index = 2  # Start at $2 for next parameter

            # Always filter for available agents
            where_parts.append("is_available = true")

            if featured:
                where_parts.append("featured = true")

            if creators and sanitized_creators:
                # Use ANY with array parameter
                where_parts.append(f"creator_username = ANY(${param_index})")
                params.append(sanitized_creators)
                param_index += 1

            if category and sanitized_category:
                where_parts.append(f"${param_index} = ANY(categories)")
                params.append(sanitized_category)
                param_index += 1

            sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1"

            # Add pagination params
            params.extend([page_size, offset])
            limit_param = f"${param_index}"
            offset_param = f"${param_index + 1}"

            # Execute full-text search query with parameterized values
            sql_query = f"""
                SELECT
                    slug,
                    agent_name,
                    agent_image,
                    creator_username,
                    creator_avatar,
                    sub_heading,
                    description,
                    runs,
                    rating,
                    categories,
                    featured,
                    is_available,
                    updated_at,
                    ts_rank_cd(search, query) AS rank
                FROM "StoreAgent",
                     plainto_tsquery('english', $1) AS query
                WHERE {sql_where_clause}
                  AND search @@ query
                ORDER BY {order_by_clause}
                LIMIT {limit_param} OFFSET {offset_param}
            """

            # Count query for pagination - only uses search term parameter
            count_query = f"""
                SELECT COUNT(*) as count
                FROM "StoreAgent",
                     plainto_tsquery('english', $1) AS query
                WHERE {sql_where_clause}
                  AND search @@ query
            """

            # Execute both queries with parameters
            agents = await prisma.client.get_client().query_raw(
                typing.cast(typing.LiteralString, sql_query), *params
            )

            # For count, use params without pagination (last 2 params)
            count_params = params[:-2]
            count_result = await prisma.client.get_client().query_raw(
                typing.cast(typing.LiteralString, count_query), *count_params
            )

            total = count_result[0]["count"] if count_result else 0
            total_pages = (total + page_size - 1) // page_size

            # Convert raw results to StoreAgent models
            store_agents: list[backend.server.v2.store.model.StoreAgent] = []
            for agent in agents:
                try:
                    store_agent = backend.server.v2.store.model.StoreAgent(
                        slug=agent["slug"],
                        agent_name=agent["agent_name"],
                        agent_image=(
                            agent["agent_image"][0] if agent["agent_image"] else ""
                        ),
                        creator=agent["creator_username"] or "Needs Profile",
                        creator_avatar=agent["creator_avatar"] or "",
                        sub_heading=agent["sub_heading"],
                        description=agent["description"],
                        runs=agent["runs"],
                        rating=agent["rating"],
                    )
                    store_agents.append(store_agent)
                except Exception as e:
                    logger.error(f"Error parsing Store agent from search results: {e}")
                    continue

        else:
            # Non-search query path (original logic)
            where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
            if featured:
                where_clause["featured"] = featured
            if creators:
                where_clause["creator_username"] = {"in": sanitized_creators}
            if sanitized_category:
                where_clause["categories"] = {"has": sanitized_category}

            order_by = []
            if sorted_by == "rating":
                order_by.append({"rating": "desc"})
            elif sorted_by == "runs":
                order_by.append({"runs": "desc"})
            elif sorted_by == "name":
                order_by.append({"agent_name": "asc"})

            agents = await prisma.models.StoreAgent.prisma().find_many(
                where=where_clause,
                order=order_by,
                skip=(page - 1) * page_size,
                take=page_size,
            )

            total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
            total_pages = (total + page_size - 1) // page_size

            store_agents: list[backend.server.v2.store.model.StoreAgent] = []
            for agent in agents:
                try:
                    # Create the StoreAgent object safely
                    store_agent = backend.server.v2.store.model.StoreAgent(
                        slug=agent.slug,
                        agent_name=agent.agent_name,
                        agent_image=agent.agent_image[0] if agent.agent_image else "",
                        creator=agent.creator_username or "Needs Profile",
                        creator_avatar=agent.creator_avatar or "",
                        sub_heading=agent.sub_heading,
                        description=agent.description,
                        runs=agent.runs,
                        rating=agent.rating,
                    )
                    # Add to the list only if creation was successful
                    store_agents.append(store_agent)
                except Exception as e:
                    # Skip this agent if there was an error
                    # You could log the error here if needed
                    logger.error(
                        f"Error parsing Store agent when getting store agents from db: {e}"
                    )
                    continue

        logger.debug(f"Found {len(store_agents)} agents")
        return backend.server.v2.store.model.StoreAgentsResponse(

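As a standalone sketch of the pattern used above: sort identifiers are taken from a fixed whitelist while every user-supplied value travels as a bind parameter, so input like `'; DROP TABLE ...` can change a parameter's value but never the shape of the SQL. The names below are illustrative, not the project's API:

```python
import typing

ALLOWED_SORTS = {"rating": "rating DESC", "runs": "runs DESC"}


def build_search_query(
    sorted_by: str | None, search_term: str
) -> tuple[str, list[typing.Any]]:
    # Identifier: looked up in a whitelist, never interpolated from the request.
    order_by = ALLOWED_SORTS.get(sorted_by or "", "updated_at DESC")
    # Value: bound as $1 by the driver, never spliced into the SQL text.
    sql = (
        'SELECT * FROM "StoreAgent" '
        "WHERE search @@ plainto_tsquery('english', $1) "
        f"ORDER BY {order_by}"
    )
    return sql, [search_term]


sql, params = build_search_query("rating; DROP TABLE Users; --", "chat agent")
assert "DROP TABLE" not in sql  # the malicious sort key fell back to the default
```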
@@ -20,7 +20,7 @@ async def setup_prisma():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_agents(mocker):
|
||||
# Mock data
|
||||
mock_agents = [
|
||||
@@ -64,7 +64,7 @@ async def test_get_store_agents(mocker):
|
||||
mock_store_agent.return_value.count.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_agent_details(mocker):
|
||||
# Mock data
|
||||
mock_agent = prisma.models.StoreAgent(
|
||||
@@ -173,7 +173,7 @@ async def test_get_store_agent_details(mocker):
|
||||
mock_store_listing_db.return_value.find_first.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_store_creator_details(mocker):
|
||||
# Mock data
|
||||
mock_creator_data = prisma.models.Creator(
|
||||
@@ -210,7 +210,7 @@ async def test_get_store_creator_details(mocker):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_create_store_submission(mocker):
|
||||
# Mock data
|
||||
mock_agent = prisma.models.AgentGraph(
|
||||
@@ -282,7 +282,7 @@ async def test_create_store_submission(mocker):
|
||||
mock_store_listing.return_value.create.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_update_profile(mocker):
|
||||
# Mock data
|
||||
mock_profile = prisma.models.Profile(
|
||||
@@ -327,7 +327,7 @@ async def test_update_profile(mocker):
|
||||
mock_profile_db.return_value.update.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_user_profile(mocker):
|
||||
# Mock data
|
||||
mock_profile = prisma.models.Profile(
|
||||
@@ -359,3 +359,63 @@ async def test_get_user_profile(mocker):
    assert result.description == "Test description"
    assert result.links == ["link1", "link2"]
    assert result.avatar_url == "avatar.jpg"


@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_agents_with_search_parameterized(mocker):
    """Test that search query uses parameterized SQL - validates the fix works"""

    # Call function with search query containing potential SQL injection
    malicious_search = "test'; DROP TABLE StoreAgent; --"
    result = await db.get_store_agents(search_query=malicious_search)

    # Verify query executed safely
    assert isinstance(result.agents, list)


@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_agents_with_search_and_filters_parameterized():
    """Test parameterized SQL with multiple filters"""

    # Call with multiple filters including potential injection attempts
    result = await db.get_store_agents(
        search_query="test",
        creators=["creator1'; DROP TABLE Users; --", "creator2"],
        category="AI'; DELETE FROM StoreAgent; --",
        featured=True,
        sorted_by="rating",
        page=1,
        page_size=20,
    )

    # Verify the query executed without error
    assert isinstance(result.agents, list)


@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_agents_search_with_invalid_sort_by():
    """Test that invalid sorted_by value doesn't cause SQL injection"""
    # Try to inject SQL via sorted_by parameter
    malicious_sort = "rating; DROP TABLE Users; --"
    result = await db.get_store_agents(
        search_query="test",
        sorted_by=malicious_sort,
    )

    # Verify the query executed without error
    # Invalid sort_by should fall back to default, not cause SQL injection
    assert isinstance(result.agents, list)


@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_agents_search_category_array_injection():
    """Test that category parameter is safely passed as a parameter"""
    # Try SQL injection via category
    malicious_category = "AI'; DROP TABLE StoreAgent; --"
    result = await db.get_store_agents(
        search_query="test",
        category=malicious_category,
    )

    # Verify the query executed without error
    # Category should be parameterized, preventing SQL injection
    assert isinstance(result.agents, list)

@@ -40,23 +40,13 @@ async def get_profile(
|
||||
Get the profile details for the authenticated user.
|
||||
Cached for 1 hour per user.
|
||||
"""
|
||||
try:
|
||||
profile = await backend.server.v2.store.db.get_user_profile(user_id)
|
||||
if profile is None:
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=404,
|
||||
content={"detail": "Profile not found"},
|
||||
)
|
||||
return profile
|
||||
except Exception as e:
|
||||
logger.exception("Failed to fetch user profile for %s: %s", user_id, e)
|
||||
profile = await backend.server.v2.store.db.get_user_profile(user_id)
|
||||
if profile is None:
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"detail": "Failed to retrieve user profile",
|
||||
"hint": "Check database connection.",
|
||||
},
|
||||
status_code=404,
|
||||
content={"detail": "Profile not found"},
|
||||
)
|
||||
return profile
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -83,20 +73,10 @@ async def update_or_create_profile(
|
||||
Raises:
|
||||
HTTPException: If there is an error updating the profile
|
||||
"""
|
||||
try:
|
||||
updated_profile = await backend.server.v2.store.db.update_profile(
|
||||
user_id=user_id, profile=profile
|
||||
)
|
||||
return updated_profile
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update profile for user %s: %s", user_id, e)
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"detail": "Failed to update user profile",
|
||||
"hint": "Validate request data.",
|
||||
},
|
||||
)
|
||||
updated_profile = await backend.server.v2.store.db.update_profile(
|
||||
user_id=user_id, profile=profile
|
||||
)
|
||||
return updated_profile
|
||||
|
||||
|
||||
##############################################
|
||||
@@ -155,26 +135,16 @@ async def get_agents(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
|
||||
try:
|
||||
agents = await store_cache._get_cached_store_agents(
|
||||
featured=featured,
|
||||
creator=creator,
|
||||
sorted_by=sorted_by,
|
||||
search_query=search_query,
|
||||
category=category,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return agents
|
||||
except Exception as e:
|
||||
logger.exception("Failed to retrieve store agents: %s", e)
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"detail": "Failed to retrieve store agents",
|
||||
"hint": "Check database or search parameters.",
|
||||
},
|
||||
)
|
||||
agents = await store_cache._get_cached_store_agents(
|
||||
featured=featured,
|
||||
creator=creator,
|
||||
sorted_by=sorted_by,
|
||||
search_query=search_query,
|
||||
category=category,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return agents
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -189,22 +159,13 @@ async def get_agent(username: str, agent_name: str):
|
||||
|
||||
It returns the store listing agents details.
|
||||
"""
|
||||
try:
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
# URL decode the agent name since it comes from the URL path
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
agent = await store_cache._get_cached_agent_details(
|
||||
username=username, agent_name=agent_name
|
||||
)
|
||||
return agent
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst getting store agent details")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"detail": "An error occurred while retrieving the store agent details"
|
||||
},
|
||||
)
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
# URL decode the agent name since it comes from the URL path
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
agent = await store_cache._get_cached_agent_details(
|
||||
username=username, agent_name=agent_name
|
||||
)
|
||||
return agent
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -217,17 +178,10 @@ async def get_graph_meta_by_store_listing_version_id(store_listing_version_id: s
|
||||
"""
|
||||
Get Agent Graph from Store Listing Version ID.
|
||||
"""
|
||||
try:
|
||||
graph = await backend.server.v2.store.db.get_available_graph(
|
||||
store_listing_version_id
|
||||
)
|
||||
return graph
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst getting agent graph")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "An error occurred while retrieving the agent graph"},
|
||||
)
|
||||
graph = await backend.server.v2.store.db.get_available_graph(
|
||||
store_listing_version_id
|
||||
)
|
||||
return graph
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -241,18 +195,11 @@ async def get_store_agent(store_listing_version_id: str):
|
||||
"""
|
||||
Get Store Agent Details from Store Listing Version ID.
|
||||
"""
|
||||
try:
|
||||
agent = await backend.server.v2.store.db.get_store_agent_by_version_id(
|
||||
store_listing_version_id
|
||||
)
|
||||
agent = await backend.server.v2.store.db.get_store_agent_by_version_id(
|
||||
store_listing_version_id
|
||||
)
|
||||
|
||||
return agent
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst getting store agent")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "An error occurred while retrieving the store agent"},
|
||||
)
|
||||
return agent
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -280,24 +227,17 @@ async def create_review(
|
||||
Returns:
|
||||
The created review
|
||||
"""
|
||||
try:
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
# Create the review
|
||||
created_review = await backend.server.v2.store.db.create_store_review(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=review.store_listing_version_id,
|
||||
score=review.score,
|
||||
comments=review.comments,
|
||||
)
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
# Create the review
|
||||
created_review = await backend.server.v2.store.db.create_store_review(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=review.store_listing_version_id,
|
||||
score=review.score,
|
||||
comments=review.comments,
|
||||
)
|
||||
|
||||
return created_review
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst creating store review")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "An error occurred while creating the store review"},
|
||||
)
|
||||
return created_review
|
||||
|
||||
|
||||
##############################################
|
||||
@@ -340,21 +280,14 @@ async def get_creators(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
|
||||
try:
|
||||
creators = await store_cache._get_cached_store_creators(
|
||||
featured=featured,
|
||||
search_query=search_query,
|
||||
sorted_by=sorted_by,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return creators
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst getting store creators")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "An error occurred while retrieving the store creators"},
|
||||
)
|
||||
creators = await store_cache._get_cached_store_creators(
|
||||
featured=featured,
|
||||
search_query=search_query,
|
||||
sorted_by=sorted_by,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return creators
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -370,18 +303,9 @@ async def get_creator(
|
||||
Get the details of a creator.
|
||||
- Creator Details Page
|
||||
"""
|
||||
try:
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
creator = await store_cache._get_cached_creator_details(username=username)
|
||||
return creator
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst getting creator details")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"detail": "An error occurred while retrieving the creator details"
|
||||
},
|
||||
)
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
creator = await store_cache._get_cached_creator_details(username=username)
|
||||
return creator
|
||||
|
||||
|
||||
############################################
|
||||
@@ -404,17 +328,10 @@ async def get_my_agents(
|
||||
"""
|
||||
Get user's own agents.
|
||||
"""
|
||||
try:
|
||||
agents = await backend.server.v2.store.db.get_my_agents(
|
||||
user_id, page=page, page_size=page_size
|
||||
)
|
||||
return agents
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst getting my agents")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "An error occurred while retrieving the my agents"},
|
||||
)
|
||||
agents = await backend.server.v2.store.db.get_my_agents(
|
||||
user_id, page=page, page_size=page_size
|
||||
)
|
||||
return agents
|
||||
|
||||
|
||||
@router.delete(
|
||||
@@ -438,19 +355,12 @@ async def delete_submission(
|
||||
Returns:
|
||||
bool: True if the submission was successfully deleted, False otherwise
|
||||
"""
|
||||
try:
|
||||
result = await backend.server.v2.store.db.delete_store_submission(
|
||||
user_id=user_id,
|
||||
submission_id=submission_id,
|
||||
)
|
||||
result = await backend.server.v2.store.db.delete_store_submission(
|
||||
user_id=user_id,
|
||||
submission_id=submission_id,
|
||||
)
|
||||
|
||||
return result
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst deleting store submission")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "An error occurred while deleting the store submission"},
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -488,21 +398,12 @@ async def get_submissions(
|
||||
raise fastapi.HTTPException(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
try:
|
||||
listings = await backend.server.v2.store.db.get_store_submissions(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return listings
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst getting store submissions")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"detail": "An error occurred while retrieving the store submissions"
|
||||
},
|
||||
)
|
||||
listings = await backend.server.v2.store.db.get_store_submissions(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return listings
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -529,36 +430,23 @@ async def create_submission(
|
||||
Raises:
|
||||
HTTPException: If there is an error creating the submission
|
||||
"""
|
||||
try:
|
||||
result = await backend.server.v2.store.db.create_store_submission(
|
||||
user_id=user_id,
|
||||
agent_id=submission_request.agent_id,
|
||||
agent_version=submission_request.agent_version,
|
||||
slug=submission_request.slug,
|
||||
name=submission_request.name,
|
||||
video_url=submission_request.video_url,
|
||||
image_urls=submission_request.image_urls,
|
||||
description=submission_request.description,
|
||||
instructions=submission_request.instructions,
|
||||
sub_heading=submission_request.sub_heading,
|
||||
categories=submission_request.categories,
|
||||
changes_summary=submission_request.changes_summary or "Initial Submission",
|
||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||
)
|
||||
result = await backend.server.v2.store.db.create_store_submission(
|
||||
user_id=user_id,
|
||||
agent_id=submission_request.agent_id,
|
||||
agent_version=submission_request.agent_version,
|
||||
slug=submission_request.slug,
|
||||
name=submission_request.name,
|
||||
video_url=submission_request.video_url,
|
||||
image_urls=submission_request.image_urls,
|
||||
description=submission_request.description,
|
||||
instructions=submission_request.instructions,
|
||||
sub_heading=submission_request.sub_heading,
|
||||
categories=submission_request.categories,
|
||||
changes_summary=submission_request.changes_summary or "Initial Submission",
|
||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||
)
|
||||
|
||||
return result
|
||||
except backend.server.v2.store.exceptions.SlugAlreadyInUseError as e:
|
||||
logger.warning("Slug already in use: %s", str(e))
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=409,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst creating store submission")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "An error occurred while creating the store submission"},
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@router.put(
|
||||
@@ -627,36 +515,10 @@ async def upload_submission_media(
|
||||
Raises:
|
||||
HTTPException: If there is an error uploading the media
|
||||
"""
|
||||
try:
|
||||
media_url = await backend.server.v2.store.media.upload_media(
|
||||
user_id=user_id, file=file
|
||||
)
|
||||
return media_url
|
||||
except backend.server.v2.store.exceptions.VirusDetectedError as e:
|
||||
logger.warning(f"Virus detected in uploaded file: {e.threat_name}")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"detail": f"File rejected due to virus detection: {e.threat_name}",
|
||||
"error_type": "virus_detected",
|
||||
"threat_name": e.threat_name,
|
||||
},
|
||||
)
|
||||
except backend.server.v2.store.exceptions.VirusScanError as e:
|
||||
logger.error(f"Virus scanning failed: {str(e)}")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=503,
|
||||
content={
|
||||
"detail": "Virus scanning service unavailable. Please try again later.",
|
||||
"error_type": "virus_scan_failed",
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst uploading submission media")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "An error occurred while uploading the media file"},
|
||||
)
|
||||
media_url = await backend.server.v2.store.media.upload_media(
|
||||
user_id=user_id, file=file
|
||||
)
|
||||
return media_url
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -679,44 +541,35 @@ async def generate_image(
|
||||
Returns:
|
||||
JSONResponse: JSON containing the URL of the generated image
|
||||
"""
|
||||
try:
|
||||
agent = await backend.data.graph.get_graph(agent_id, user_id=user_id)
|
||||
agent = await backend.data.graph.get_graph(agent_id, user_id=user_id)
|
||||
|
||||
if not agent:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404, detail=f"Agent with ID {agent_id} not found"
|
||||
)
|
||||
# Use .jpeg here since we are generating JPEG images
|
||||
filename = f"agent_{agent_id}.jpeg"
|
||||
if not agent:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404, detail=f"Agent with ID {agent_id} not found"
|
||||
)
|
||||
# Use .jpeg here since we are generating JPEG images
|
||||
filename = f"agent_{agent_id}.jpeg"
|
||||
|
||||
existing_url = await backend.server.v2.store.media.check_media_exists(
|
||||
user_id, filename
|
||||
)
|
||||
if existing_url:
|
||||
logger.info(f"Using existing image for agent {agent_id}")
|
||||
return fastapi.responses.JSONResponse(content={"image_url": existing_url})
|
||||
# Generate agent image as JPEG
|
||||
image = await backend.server.v2.store.image_gen.generate_agent_image(
|
||||
agent=agent
|
||||
)
|
||||
existing_url = await backend.server.v2.store.media.check_media_exists(
|
||||
user_id, filename
|
||||
)
|
||||
if existing_url:
|
||||
logger.info(f"Using existing image for agent {agent_id}")
|
||||
return fastapi.responses.JSONResponse(content={"image_url": existing_url})
|
||||
# Generate agent image as JPEG
|
||||
image = await backend.server.v2.store.image_gen.generate_agent_image(agent=agent)
|
||||
|
||||
# Create UploadFile with the correct filename and content_type
|
||||
image_file = fastapi.UploadFile(
|
||||
file=image,
|
||||
filename=filename,
|
||||
)
|
||||
# Create UploadFile with the correct filename and content_type
|
||||
image_file = fastapi.UploadFile(
|
||||
file=image,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
image_url = await backend.server.v2.store.media.upload_media(
|
||||
user_id=user_id, file=image_file, use_file_name=True
|
||||
)
|
||||
image_url = await backend.server.v2.store.media.upload_media(
|
||||
user_id=user_id, file=image_file, use_file_name=True
|
||||
)
|
||||
|
||||
return fastapi.responses.JSONResponse(content={"image_url": image_url})
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst generating submission image")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "An error occurred while generating the image"},
|
||||
)
|
||||
return fastapi.responses.JSONResponse(content={"image_url": image_url})
|
||||
|
||||
|
||||
@router.get(
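The route hunks above strip the per-handler try/except blocks that turned every failure into an ad-hoc 500 JSON response. That only makes sense if unexpected errors are caught in one place instead; the snippet below is a hedged sketch of such a catch-all handler, not code from this PR.

```python
import logging

import fastapi
from fastapi.responses import JSONResponse

logger = logging.getLogger(__name__)
app = fastapi.FastAPI()


@app.exception_handler(Exception)
async def unhandled_exception_handler(
    request: fastapi.Request, exc: Exception
) -> JSONResponse:
    # Log the full traceback once, centrally, instead of in every route.
    logger.exception("Unhandled error on %s %s", request.method, request.url.path)
    return JSONResponse(status_code=500, content={"detail": "Internal server error"})
```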
@@ -329,7 +329,3 @@ class WebsocketServer(AppProcess):
|
||||
port=Config().websocket_server_port,
|
||||
log_config=None,
|
||||
)
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
logger.info(f"[{self.service_name}] ⏳ Shutting down WebSocket Server...")
|
||||
|
||||
autogpt_platform/backend/backend/util/cache.py (new file, 457 lines)
@@ -0,0 +1,457 @@
"""
Caching utilities for the AutoGPT platform.

Provides decorators for caching function results with support for:
- In-memory caching with TTL
- Shared Redis-backed caching across processes
- Thread-local caching for request-scoped data
- Thundering herd protection
- LRU eviction with optional TTL refresh
"""

import asyncio
import inspect
import logging
import pickle
import threading
import time
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, ParamSpec, Protocol, TypeVar, cast, runtime_checkable

from redis import ConnectionPool, Redis

from backend.util.retry import conn_retry
from backend.util.settings import Settings

P = ParamSpec("P")
R = TypeVar("R")
R_co = TypeVar("R_co", covariant=True)

logger = logging.getLogger(__name__)
settings = Settings()

# RECOMMENDED REDIS CONFIGURATION FOR PRODUCTION:
# Configure Redis with the following settings for optimal caching performance:
# maxmemory-policy allkeys-lru  # Evict least recently used keys when memory limit reached
# maxmemory 2gb                 # Set memory limit (adjust based on your needs)
# save ""                       # Disable persistence if using Redis purely for caching

# Create a dedicated Redis connection pool for caching (binary mode for pickle)
_cache_pool: ConnectionPool | None = None
|
||||
|
||||
|
||||
@conn_retry("Redis", "Acquiring cache connection pool")
|
||||
def _get_cache_pool() -> ConnectionPool:
|
||||
"""Get or create a connection pool for cache operations."""
|
||||
global _cache_pool
|
||||
if _cache_pool is None:
|
||||
_cache_pool = ConnectionPool(
|
||||
host=settings.config.redis_host,
|
||||
port=settings.config.redis_port,
|
||||
password=settings.config.redis_password or None,
|
||||
decode_responses=False, # Binary mode for pickle
|
||||
max_connections=50,
|
||||
socket_keepalive=True,
|
||||
socket_connect_timeout=5,
|
||||
retry_on_timeout=True,
|
||||
)
|
||||
return _cache_pool
|
||||
|
||||
|
||||
redis = Redis(connection_pool=_get_cache_pool())
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedValue:
|
||||
"""Wrapper for cached values with timestamp to avoid tuple ambiguity."""
|
||||
|
||||
result: Any
|
||||
timestamp: float
|
||||
|
||||
|
||||
def _make_hashable_key(
|
||||
args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> tuple[Any, ...]:
|
||||
"""
|
||||
Convert args and kwargs into a hashable cache key.
|
||||
|
||||
Handles unhashable types like dict, list, set by converting them to
|
||||
their sorted string representations.
|
||||
"""
|
||||
|
||||
def make_hashable(obj: Any) -> Any:
|
||||
"""Recursively convert an object to a hashable representation."""
|
||||
if isinstance(obj, dict):
|
||||
# Sort dict items to ensure consistent ordering
|
||||
return (
|
||||
"__dict__",
|
||||
tuple(sorted((k, make_hashable(v)) for k, v in obj.items())),
|
||||
)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return ("__list__", tuple(make_hashable(item) for item in obj))
|
||||
elif isinstance(obj, set):
|
||||
return ("__set__", tuple(sorted(make_hashable(item) for item in obj)))
|
||||
elif hasattr(obj, "__dict__"):
|
||||
# Handle objects with __dict__ attribute
|
||||
return ("__obj__", obj.__class__.__name__, make_hashable(obj.__dict__))
|
||||
else:
|
||||
# For basic hashable types (str, int, bool, None, etc.)
|
||||
try:
|
||||
hash(obj)
|
||||
return obj
|
||||
except TypeError:
|
||||
# Fallback: convert to string representation
|
||||
return ("__str__", str(obj))
|
||||
|
||||
hashable_args = tuple(make_hashable(arg) for arg in args)
|
||||
hashable_kwargs = tuple(sorted((k, make_hashable(v)) for k, v in kwargs.items()))
|
||||
return (hashable_args, hashable_kwargs)
|
||||
|
||||
|
||||
def _make_redis_key(key: tuple[Any, ...], func_name: str) -> str:
|
||||
"""Convert a hashable key tuple to a Redis key string."""
|
||||
# Ensure key is already hashable
|
||||
hashable_key = key if isinstance(key, tuple) else (key,)
|
||||
return f"cache:{func_name}:{hash(hashable_key)}"
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class CachedFunction(Protocol[P, R_co]):
|
||||
"""Protocol for cached functions with cache management methods."""
|
||||
|
||||
def cache_clear(self, pattern: str | None = None) -> None:
|
||||
"""Clear cached entries. If pattern provided, clear matching entries only."""
|
||||
return None
|
||||
|
||||
def cache_info(self) -> dict[str, int | None]:
|
||||
"""Get cache statistics."""
|
||||
return {}
|
||||
|
||||
def cache_delete(self, *args: P.args, **kwargs: P.kwargs) -> bool:
|
||||
"""Delete a specific cache entry by its arguments. Returns True if entry existed."""
|
||||
return False
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
|
||||
"""Call the cached function."""
|
||||
return None # type: ignore
|
||||
|
||||
|
||||
def cached(
|
||||
*,
|
||||
maxsize: int = 128,
|
||||
ttl_seconds: int,
|
||||
shared_cache: bool = False,
|
||||
refresh_ttl_on_get: bool = False,
|
||||
) -> Callable[[Callable], CachedFunction]:
|
||||
"""
|
||||
Thundering herd safe cache decorator for both sync and async functions.
|
||||
|
||||
Uses double-checked locking to prevent multiple threads/coroutines from
|
||||
executing the expensive operation simultaneously during cache misses.
|
||||
|
||||
Args:
|
||||
maxsize: Maximum number of cached entries (only for in-memory cache)
|
||||
ttl_seconds: Time to live in seconds. Required - entries must expire.
|
||||
shared_cache: If True, use Redis for cross-process caching
|
||||
refresh_ttl_on_get: If True, refresh TTL when cache entry is accessed (LRU behavior)
|
||||
|
||||
Returns:
|
||||
Decorated function with caching capabilities
|
||||
|
||||
Example:
|
||||
@cached(ttl_seconds=300) # 5 minute TTL
|
||||
def expensive_sync_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
@cached(ttl_seconds=600, shared_cache=True, refresh_ttl_on_get=True)
|
||||
async def expensive_async_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
"""
|
||||
|
||||
def decorator(target_func):
|
||||
cache_storage: dict[tuple, CachedValue] = {}
|
||||
_event_loop_locks: dict[Any, asyncio.Lock] = {}
|
||||
|
||||
def _get_from_redis(redis_key: str) -> Any | None:
|
||||
"""Get value from Redis, optionally refreshing TTL."""
|
||||
try:
|
||||
if refresh_ttl_on_get:
|
||||
# Use GETEX to get value and refresh expiry atomically
|
||||
cached_bytes = redis.getex(redis_key, ex=ttl_seconds)
|
||||
else:
|
||||
cached_bytes = redis.get(redis_key)
|
||||
|
||||
if cached_bytes and isinstance(cached_bytes, bytes):
|
||||
return pickle.loads(cached_bytes)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Redis error during cache check for {target_func.__name__}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def _set_to_redis(redis_key: str, value: Any) -> None:
|
||||
"""Set value in Redis with TTL."""
|
||||
try:
|
||||
pickled_value = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
redis.setex(redis_key, ttl_seconds, pickled_value)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Redis error storing cache for {target_func.__name__}: {e}"
|
||||
)
|
||||
|
||||
def _get_from_memory(key: tuple) -> Any | None:
|
||||
"""Get value from in-memory cache, checking TTL."""
|
||||
if key in cache_storage:
|
||||
cached_data = cache_storage[key]
|
||||
if time.time() - cached_data.timestamp < ttl_seconds:
|
||||
logger.debug(
|
||||
f"Cache hit for {target_func.__name__} args: {key[0]} kwargs: {key[1]}"
|
||||
)
|
||||
return cached_data.result
|
||||
return None
|
||||
|
||||
def _set_to_memory(key: tuple, value: Any) -> None:
|
||||
"""Set value in in-memory cache with timestamp."""
|
||||
cache_storage[key] = CachedValue(result=value, timestamp=time.time())
|
||||
|
||||
# Cleanup if needed
|
||||
if len(cache_storage) > maxsize:
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
|
||||
if inspect.iscoroutinefunction(target_func):
|
||||
|
||||
def _get_cache_lock():
|
||||
"""Get or create an asyncio.Lock for the current event loop."""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
if loop not in _event_loop_locks:
|
||||
return _event_loop_locks.setdefault(loop, asyncio.Lock())
|
||||
return _event_loop_locks[loop]
|
||||
|
||||
@wraps(target_func)
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
redis_key = (
|
||||
_make_redis_key(key, target_func.__name__) if shared_cache else ""
|
||||
)
|
||||
|
||||
# Fast path: check cache without lock
|
||||
if shared_cache:
|
||||
result = _get_from_redis(redis_key)
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
result = _get_from_memory(key)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
async with _get_cache_lock():
|
||||
# Double-check: another coroutine might have populated cache
|
||||
if shared_cache:
|
||||
result = _get_from_redis(redis_key)
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
result = _get_from_memory(key)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {target_func.__name__}")
|
||||
result = await target_func(*args, **kwargs)
|
||||
|
||||
# Store result
|
||||
if shared_cache:
|
||||
_set_to_redis(redis_key, result)
|
||||
else:
|
||||
_set_to_memory(key, result)
|
||||
|
||||
return result
|
||||
|
||||
wrapper = async_wrapper
|
||||
|
||||
else:
|
||||
# Sync function with threading.Lock
|
||||
cache_lock = threading.Lock()
|
||||
|
||||
@wraps(target_func)
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
redis_key = (
|
||||
_make_redis_key(key, target_func.__name__) if shared_cache else ""
|
||||
)
|
||||
|
||||
# Fast path: check cache without lock
|
||||
if shared_cache:
|
||||
result = _get_from_redis(redis_key)
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
result = _get_from_memory(key)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
with cache_lock:
|
||||
# Double-check: another thread might have populated cache
|
||||
if shared_cache:
|
||||
result = _get_from_redis(redis_key)
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
result = _get_from_memory(key)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {target_func.__name__}")
|
||||
result = target_func(*args, **kwargs)
|
||||
|
||||
# Store result
|
||||
if shared_cache:
|
||||
_set_to_redis(redis_key, result)
|
||||
else:
|
||||
_set_to_memory(key, result)
|
||||
|
||||
return result
|
||||
|
||||
wrapper = sync_wrapper
|
||||
|
||||
# Add cache management methods
|
||||
def cache_clear(pattern: str | None = None) -> None:
|
||||
"""Clear cache entries. If pattern provided, clear matching entries."""
|
||||
if shared_cache:
|
||||
if pattern:
|
||||
# Clear entries matching pattern
|
||||
keys = list(
|
||||
redis.scan_iter(f"cache:{target_func.__name__}:{pattern}")
|
||||
)
|
||||
else:
|
||||
# Clear all cache keys
|
||||
keys = list(redis.scan_iter(f"cache:{target_func.__name__}:*"))
|
||||
|
||||
if keys:
|
||||
pipeline = redis.pipeline()
|
||||
for key in keys:
|
||||
pipeline.delete(key)
|
||||
pipeline.execute()
|
||||
else:
|
||||
if pattern:
|
||||
# For in-memory cache, pattern matching not supported
|
||||
logger.warning(
|
||||
"Pattern-based clearing not supported for in-memory cache"
|
||||
)
|
||||
else:
|
||||
cache_storage.clear()
|
||||
|
||||
def cache_info() -> dict[str, int | None]:
|
||||
if shared_cache:
|
||||
cache_keys = list(redis.scan_iter(f"cache:{target_func.__name__}:*"))
|
||||
return {
|
||||
"size": len(cache_keys),
|
||||
"maxsize": None, # Redis manages its own size
|
||||
"ttl_seconds": ttl_seconds,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"size": len(cache_storage),
|
||||
"maxsize": maxsize,
|
||||
"ttl_seconds": ttl_seconds,
|
||||
}
|
||||
|
||||
def cache_delete(*args, **kwargs) -> bool:
|
||||
"""Delete a specific cache entry. Returns True if entry existed."""
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
if shared_cache:
|
||||
redis_key = _make_redis_key(key, target_func.__name__)
|
||||
if redis.exists(redis_key):
|
||||
redis.delete(redis_key)
|
||||
return True
|
||||
return False
|
||||
else:
|
||||
if key in cache_storage:
|
||||
del cache_storage[key]
|
||||
return True
|
||||
return False
|
||||
|
||||
setattr(wrapper, "cache_clear", cache_clear)
|
||||
setattr(wrapper, "cache_info", cache_info)
|
||||
setattr(wrapper, "cache_delete", cache_delete)
|
||||
|
||||
return cast(CachedFunction, wrapper)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def thread_cached(func):
|
||||
"""
|
||||
Thread-local cache decorator for both sync and async functions.
|
||||
|
||||
Each thread gets its own cache, which is useful for request-scoped caching
|
||||
in web applications where you want to cache within a single request but
|
||||
not across requests.
|
||||
|
||||
Args:
|
||||
func: The function to cache
|
||||
|
||||
Returns:
|
||||
Decorated function with thread-local caching
|
||||
|
||||
Example:
|
||||
@thread_cached
|
||||
def expensive_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
@thread_cached # Works with async too
|
||||
async def expensive_async_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
"""
|
||||
thread_local = threading.local()
|
||||
|
||||
def _clear():
|
||||
if hasattr(thread_local, "cache"):
|
||||
del thread_local.cache
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
if key not in cache:
|
||||
cache[key] = await func(*args, **kwargs)
|
||||
return cache[key]
|
||||
|
||||
setattr(async_wrapper, "clear_cache", _clear)
|
||||
return async_wrapper
|
||||
|
||||
else:
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
if key not in cache:
|
||||
cache[key] = func(*args, **kwargs)
|
||||
return cache[key]
|
||||
|
||||
setattr(sync_wrapper, "clear_cache", _clear)
|
||||
return sync_wrapper
|
||||
|
||||
|
||||
def clear_thread_cache(func: Callable) -> None:
|
||||
"""Clear thread-local cache for a function."""
|
||||
if clear := getattr(func, "clear_cache", None):
|
||||
clear()
|
||||
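Taken together, the new module gives callers an explicit TTL, optional Redis sharing, and per-entry invalidation. A short usage sketch follows; the decorated function is made up for illustration.

```python
from backend.util.cache import cached


@cached(ttl_seconds=300, shared_cache=True, refresh_ttl_on_get=True)
async def count_listings(category: str) -> int:
    # Stand-in for an expensive database query.
    return len(category)


# First call computes and stores the value in Redis for 300s; repeat calls
# with the same argument are served from the cache and refresh the TTL.
#   await count_listings("ai")
# Per-entry invalidation and statistics:
#   count_listings.cache_delete("ai")   # True if an entry existed
#   count_listings.cache_info()         # {"size": ..., "maxsize": None, "ttl_seconds": 300}
```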
@@ -16,7 +16,7 @@ from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from autogpt_libs.utils.cache import cached, clear_thread_cache, thread_cached
|
||||
from backend.util.cache import cached, clear_thread_cache, thread_cached
|
||||
|
||||
|
||||
class TestThreadCached:
|
||||
@@ -332,7 +332,7 @@ class TestCache:
|
||||
"""Test basic sync caching functionality."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=300)
|
||||
def expensive_sync_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -358,7 +358,7 @@ class TestCache:
|
||||
"""Test basic async caching functionality."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=300)
|
||||
async def expensive_async_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -385,7 +385,7 @@ class TestCache:
|
||||
call_count = 0
|
||||
results = []
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=300)
|
||||
def slow_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -412,7 +412,7 @@ class TestCache:
|
||||
"""Test that concurrent async calls don't cause thundering herd."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=300)
|
||||
async def slow_async_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -508,7 +508,7 @@ class TestCache:
|
||||
"""Test cache clearing functionality."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=300)
|
||||
def clearable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -537,7 +537,7 @@ class TestCache:
|
||||
"""Test cache clearing functionality with async function."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=300)
|
||||
async def async_clearable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -567,7 +567,7 @@ class TestCache:
|
||||
"""Test that cached async functions return actual results, not coroutines."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=300)
|
||||
async def async_result_function(x: int) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -593,7 +593,7 @@ class TestCache:
|
||||
"""Test selective cache deletion functionality."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=300)
|
||||
def deletable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -636,7 +636,7 @@ class TestCache:
|
||||
"""Test selective cache deletion functionality with async function."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=300)
|
||||
async def async_deletable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
@@ -674,3 +674,450 @@ class TestCache:
|
||||
# Try to delete non-existent entry
|
||||
was_deleted = async_deletable_function.cache_delete(99)
|
||||
assert was_deleted is False
|
||||
|
||||
|
||||
class TestSharedCache:
|
||||
"""Tests for shared_cache (Redis-backed) functionality."""
|
||||
|
||||
def test_sync_shared_cache_basic(self):
|
||||
"""Test basic shared cache functionality with sync function."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
def shared_sync_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x + y
|
||||
|
||||
# Clear any existing cache
|
||||
shared_sync_function.cache_clear()
|
||||
|
||||
# First call
|
||||
result1 = shared_sync_function(10, 20)
|
||||
assert result1 == 30
|
||||
assert call_count == 1
|
||||
|
||||
# Second call - should use Redis cache
|
||||
result2 = shared_sync_function(10, 20)
|
||||
assert result2 == 30
|
||||
assert call_count == 1
|
||||
|
||||
# Different args - should call function again
|
||||
result3 = shared_sync_function(15, 25)
|
||||
assert result3 == 40
|
||||
assert call_count == 2
|
||||
|
||||
# Cleanup
|
||||
shared_sync_function.cache_clear()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_shared_cache_basic(self):
|
||||
"""Test basic shared cache functionality with async function."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
async def shared_async_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
return x + y
|
||||
|
||||
# Clear any existing cache
|
||||
shared_async_function.cache_clear()
|
||||
|
||||
# First call
|
||||
result1 = await shared_async_function(10, 20)
|
||||
assert result1 == 30
|
||||
assert call_count == 1
|
||||
|
||||
# Second call - should use Redis cache
|
||||
result2 = await shared_async_function(10, 20)
|
||||
assert result2 == 30
|
||||
assert call_count == 1
|
||||
|
||||
# Different args - should call function again
|
||||
result3 = await shared_async_function(15, 25)
|
||||
assert result3 == 40
|
||||
assert call_count == 2
|
||||
|
||||
# Cleanup
|
||||
shared_async_function.cache_clear()
|
||||
|
||||
def test_shared_cache_ttl_refresh(self):
|
||||
"""Test TTL refresh functionality with shared cache."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=2, shared_cache=True, refresh_ttl_on_get=True)
|
||||
def ttl_refresh_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 10
|
||||
|
||||
# Clear any existing cache
|
||||
ttl_refresh_function.cache_clear()
|
||||
|
||||
# First call
|
||||
result1 = ttl_refresh_function(3)
|
||||
assert result1 == 30
|
||||
assert call_count == 1
|
||||
|
||||
# Wait 1 second
|
||||
time.sleep(1)
|
||||
|
||||
# Second call - should refresh TTL and use cache
|
||||
result2 = ttl_refresh_function(3)
|
||||
assert result2 == 30
|
||||
assert call_count == 1
|
||||
|
||||
# Wait another 1.5 seconds (total 2.5s from first call, 1.5s from second)
|
||||
time.sleep(1.5)
|
||||
|
||||
# Third call - TTL should have been refreshed, so still cached
|
||||
result3 = ttl_refresh_function(3)
|
||||
assert result3 == 30
|
||||
assert call_count == 1
|
||||
|
||||
# Wait 2.1 seconds - now it should expire
|
||||
time.sleep(2.1)
|
||||
|
||||
# Fourth call - should call function again
|
||||
result4 = ttl_refresh_function(3)
|
||||
assert result4 == 30
|
||||
assert call_count == 2
|
||||
|
||||
# Cleanup
|
||||
ttl_refresh_function.cache_clear()
|
||||
|
||||
def test_shared_cache_without_ttl_refresh(self):
|
||||
"""Test that TTL doesn't refresh when refresh_ttl_on_get=False."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=2, shared_cache=True, refresh_ttl_on_get=False)
|
||||
def no_ttl_refresh_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 10
|
||||
|
||||
# Clear any existing cache
|
||||
no_ttl_refresh_function.cache_clear()
|
||||
|
||||
# First call
|
||||
result1 = no_ttl_refresh_function(4)
|
||||
assert result1 == 40
|
||||
assert call_count == 1
|
||||
|
||||
# Wait 1 second
|
||||
time.sleep(1)
|
||||
|
||||
# Second call - should use cache but NOT refresh TTL
|
||||
result2 = no_ttl_refresh_function(4)
|
||||
assert result2 == 40
|
||||
assert call_count == 1
|
||||
|
||||
# Wait another 1.1 seconds (total 2.1s from first call)
|
||||
time.sleep(1.1)
|
||||
|
||||
# Third call - should have expired
|
||||
result3 = no_ttl_refresh_function(4)
|
||||
assert result3 == 40
|
||||
assert call_count == 2
|
||||
|
||||
# Cleanup
|
||||
no_ttl_refresh_function.cache_clear()
|
||||
|
||||
def test_shared_cache_complex_objects(self):
|
||||
"""Test caching complex objects with shared cache (pickle serialization)."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
def complex_object_function(x: int) -> dict:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return {
|
||||
"number": x,
|
||||
"squared": x**2,
|
||||
"nested": {"list": [1, 2, x], "tuple": (x, x * 2)},
|
||||
"string": f"value_{x}",
|
||||
}
|
||||
|
||||
# Clear any existing cache
|
||||
complex_object_function.cache_clear()
|
||||
|
||||
# First call
|
||||
result1 = complex_object_function(5)
|
||||
assert result1["number"] == 5
|
||||
assert result1["squared"] == 25
|
||||
assert result1["nested"]["list"] == [1, 2, 5]
|
||||
assert call_count == 1
|
||||
|
||||
# Second call - should use cache
|
||||
result2 = complex_object_function(5)
|
||||
assert result2 == result1
|
||||
assert call_count == 1
|
||||
|
||||
# Cleanup
|
||||
complex_object_function.cache_clear()
|
||||
|
||||
def test_shared_cache_info(self):
|
||||
"""Test cache_info for shared cache."""
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
def info_shared_function(x: int) -> int:
|
||||
return x * 2
|
||||
|
||||
# Clear any existing cache
|
||||
info_shared_function.cache_clear()
|
||||
|
||||
# Check initial info
|
||||
info = info_shared_function.cache_info()
|
||||
assert info["size"] == 0
|
||||
assert info["maxsize"] is None # Redis manages size
|
||||
assert info["ttl_seconds"] == 30
|
||||
|
||||
# Add some entries
|
||||
info_shared_function(1)
|
||||
info_shared_function(2)
|
||||
info_shared_function(3)
|
||||
|
||||
info = info_shared_function.cache_info()
|
||||
assert info["size"] == 3
|
||||
|
||||
# Cleanup
|
||||
info_shared_function.cache_clear()
|
||||
|
||||
def test_shared_cache_delete(self):
|
||||
"""Test selective deletion with shared cache."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
def delete_shared_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 3
|
||||
|
||||
# Clear any existing cache
|
||||
delete_shared_function.cache_clear()
|
||||
|
||||
# Add entries
|
||||
delete_shared_function(1)
|
||||
delete_shared_function(2)
|
||||
delete_shared_function(3)
|
||||
assert call_count == 3
|
||||
|
||||
# Verify cached
|
||||
delete_shared_function(1)
|
||||
delete_shared_function(2)
|
||||
assert call_count == 3
|
||||
|
||||
# Delete specific entry
|
||||
was_deleted = delete_shared_function.cache_delete(2)
|
||||
assert was_deleted is True
|
||||
|
||||
# Entry for x=2 should be gone
|
||||
delete_shared_function(2)
|
||||
assert call_count == 4
|
||||
|
||||
# Others should still be cached
|
||||
delete_shared_function(1)
|
||||
delete_shared_function(3)
|
||||
assert call_count == 4
|
||||
|
||||
# Try to delete non-existent
|
||||
was_deleted = delete_shared_function.cache_delete(99)
|
||||
assert was_deleted is False
|
||||
|
||||
# Cleanup
|
||||
delete_shared_function.cache_clear()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_shared_cache_thundering_herd(self):
|
||||
"""Test that shared cache prevents thundering herd for async functions."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
async def shared_slow_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.1)
|
||||
return x * x
|
||||
|
||||
# Clear any existing cache
|
||||
shared_slow_function.cache_clear()
|
||||
|
||||
# Launch multiple concurrent tasks
|
||||
tasks = [shared_slow_function(8) for _ in range(10)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# All should return same result
|
||||
assert all(r == 64 for r in results)
|
||||
# Only one should have executed
|
||||
assert call_count == 1
|
||||
|
||||
# Cleanup
|
||||
shared_slow_function.cache_clear()
|
||||
|
||||
def test_shared_cache_clear_pattern(self):
|
||||
"""Test pattern-based cache clearing (Redis feature)."""
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
def pattern_function(category: str, item: int) -> str:
|
||||
return f"{category}_{item}"
|
||||
|
||||
# Clear any existing cache
|
||||
pattern_function.cache_clear()
|
||||
|
||||
# Add various entries
|
||||
pattern_function("fruit", 1)
|
||||
pattern_function("fruit", 2)
|
||||
pattern_function("vegetable", 1)
|
||||
pattern_function("vegetable", 2)
|
||||
|
||||
info = pattern_function.cache_info()
|
||||
assert info["size"] == 4
|
||||
|
||||
# Note: Pattern clearing with wildcards requires specific Redis scan
|
||||
# implementation. The current code clears by pattern but needs
|
||||
# adjustment for partial matching. For now, test full clear.
|
||||
pattern_function.cache_clear()
|
||||
info = pattern_function.cache_info()
|
||||
assert info["size"] == 0
|
||||
|
||||
def test_shared_vs_local_cache_isolation(self):
|
||||
"""Test that shared and local caches are isolated."""
|
||||
shared_count = 0
|
||||
local_count = 0
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
def shared_function(x: int) -> int:
|
||||
nonlocal shared_count
|
||||
shared_count += 1
|
||||
return x * 2
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=False)
|
||||
def local_function(x: int) -> int:
|
||||
nonlocal local_count
|
||||
local_count += 1
|
||||
return x * 2
|
||||
|
||||
# Clear caches
|
||||
shared_function.cache_clear()
|
||||
local_function.cache_clear()
|
||||
|
||||
# Call both with same args
|
||||
shared_result = shared_function(5)
|
||||
local_result = local_function(5)
|
||||
|
||||
assert shared_result == local_result == 10
|
||||
assert shared_count == 1
|
||||
assert local_count == 1
|
||||
|
||||
# Call again - both should use their respective caches
|
||||
shared_function(5)
|
||||
local_function(5)
|
||||
assert shared_count == 1
|
||||
assert local_count == 1
|
||||
|
||||
# Clear only shared cache
|
||||
shared_function.cache_clear()
|
||||
|
||||
# Shared should recompute, local should still use cache
|
||||
shared_function(5)
|
||||
local_function(5)
|
||||
assert shared_count == 2
|
||||
assert local_count == 1
|
||||
|
||||
# Cleanup
|
||||
shared_function.cache_clear()
|
||||
local_function.cache_clear()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shared_cache_concurrent_different_keys(self):
|
||||
"""Test that concurrent calls with different keys work correctly."""
|
||||
call_counts = {}
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
async def multi_key_function(key: str) -> str:
|
||||
if key not in call_counts:
|
||||
call_counts[key] = 0
|
||||
call_counts[key] += 1
|
||||
await asyncio.sleep(0.05)
|
||||
return f"result_{key}"
|
||||
|
||||
# Clear cache
|
||||
multi_key_function.cache_clear()
|
||||
|
||||
# Launch concurrent tasks with different keys
|
||||
keys = ["a", "b", "c", "d", "e"]
|
||||
tasks = []
|
||||
for key in keys:
|
||||
# Multiple calls per key
|
||||
tasks.extend([multi_key_function(key) for _ in range(3)])
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Verify results
|
||||
for i, key in enumerate(keys):
|
||||
expected = f"result_{key}"
|
||||
# Each key appears 3 times in results
|
||||
key_results = results[i * 3 : (i + 1) * 3]
|
||||
assert all(r == expected for r in key_results)
|
||||
|
||||
# Each key should only be computed once
|
||||
for key in keys:
|
||||
assert call_counts[key] == 1
|
||||
|
||||
# Cleanup
|
||||
multi_key_function.cache_clear()
|
||||
|
||||
def test_shared_cache_performance_comparison(self):
|
||||
"""Compare performance of shared vs local cache."""
|
||||
import statistics
|
||||
|
||||
shared_times = []
|
||||
local_times = []
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
def shared_perf_function(x: int) -> int:
|
||||
time.sleep(0.01) # Simulate work
|
||||
return x * 2
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=False)
|
||||
def local_perf_function(x: int) -> int:
|
||||
time.sleep(0.01) # Simulate work
|
||||
return x * 2
|
||||
|
||||
# Clear caches
|
||||
shared_perf_function.cache_clear()
|
||||
local_perf_function.cache_clear()
|
||||
|
||||
# Warm up both caches
|
||||
for i in range(5):
|
||||
shared_perf_function(i)
|
||||
local_perf_function(i)
|
||||
|
||||
# Measure cache hit times
|
||||
for i in range(5):
|
||||
# Shared cache hit
|
||||
start = time.time()
|
||||
shared_perf_function(i)
|
||||
shared_times.append(time.time() - start)
|
||||
|
||||
# Local cache hit
|
||||
start = time.time()
|
||||
local_perf_function(i)
|
||||
local_times.append(time.time() - start)
|
||||
|
||||
# Local cache should be faster (no Redis round-trip)
|
||||
avg_shared = statistics.mean(shared_times)
|
||||
avg_local = statistics.mean(local_times)
|
||||
|
||||
print(f"Avg shared cache hit time: {avg_shared:.6f}s")
|
||||
print(f"Avg local cache hit time: {avg_local:.6f}s")
|
||||
|
||||
# Local should be significantly faster for cache hits
|
||||
# Redis adds network latency even for cache hits
|
||||
assert avg_local < avg_shared
|
||||
|
||||
# Cleanup
|
||||
shared_perf_function.cache_clear()
|
||||
local_perf_function.cache_clear()
|
||||
@@ -4,8 +4,7 @@ Centralized service client helpers with thread caching.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from autogpt_libs.utils.cache import cached, thread_cached
|
||||
|
||||
from backend.util.cache import cached, thread_cached
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
@@ -120,7 +119,7 @@ def get_integration_credentials_store() -> "IntegrationCredentialsStore":
|
||||
# ============ Supabase Clients ============ #
|
||||
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=3600)
|
||||
def get_supabase() -> "Client":
|
||||
"""Get a process-cached synchronous Supabase client instance."""
|
||||
from supabase import create_client
|
||||
@@ -130,7 +129,7 @@ def get_supabase() -> "Client":
|
||||
)
|
||||
|
||||
|
||||
@cached()
|
||||
@cached(ttl_seconds=3600)
|
||||
async def get_async_supabase() -> "AClient":
|
||||
"""Get a process-cached asynchronous Supabase client instance."""
|
||||
from supabase import create_async_client
|
||||
|
||||
@@ -5,12 +5,12 @@ from functools import wraps
|
||||
from typing import Any, Awaitable, Callable, TypeVar
|
||||
|
||||
import ldclient
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from fastapi import HTTPException
|
||||
from ldclient import Context, LDClient
|
||||
from ldclient.config import Config
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from backend.util.cache import cached
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -35,6 +35,7 @@ class Flag(str, Enum):
|
||||
AI_ACTIVITY_STATUS = "ai-agent-execution-summary"
|
||||
BETA_BLOCKS = "beta-blocks"
|
||||
AGENT_ACTIVITY = "agent-activity"
|
||||
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
|
||||
|
||||
|
||||
def is_configured() -> bool:
|
||||
@@ -62,9 +63,9 @@ def initialize_launchdarkly() -> None:
|
||||
config = Config(sdk_key)
|
||||
ldclient.set_config(config)
|
||||
|
||||
global _is_initialized
|
||||
_is_initialized = True
|
||||
if ldclient.get().is_initialized():
|
||||
global _is_initialized
|
||||
_is_initialized = True
|
||||
logger.info("LaunchDarkly client initialized successfully")
|
||||
else:
|
||||
logger.error("LaunchDarkly client failed to initialize")
|
||||
@@ -217,7 +218,8 @@ def feature_flag(
|
||||
|
||||
if not get_client().is_initialized():
|
||||
logger.warning(
|
||||
f"LaunchDarkly not initialized, using default={default}"
|
||||
"LaunchDarkly not initialized, "
|
||||
f"using default {flag_key}={repr(default)}"
|
||||
)
|
||||
is_enabled = default
|
||||
else:
|
||||
@@ -231,8 +233,9 @@ def feature_flag(
|
||||
else:
|
||||
# Log warning and use default for non-boolean values
|
||||
logger.warning(
|
||||
f"Feature flag {flag_key} returned non-boolean value: {flag_value} (type: {type(flag_value).__name__}). "
|
||||
f"Using default={default}"
|
||||
f"Feature flag {flag_key} returned non-boolean value: "
|
||||
f"{repr(flag_value)} (type: {type(flag_value).__name__}). "
|
||||
f"Using default value {repr(default)}"
|
||||
)
|
||||
is_enabled = default
|
||||
|
||||
|
||||
@@ -1,35 +1,22 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Type, TypeGuard, TypeVar, overload
|
||||
from typing import Any, Type, TypeVar, overload
|
||||
|
||||
import jsonschema
|
||||
import orjson
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.encoders import jsonable_encoder as to_dict
|
||||
from prisma import Json
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .truncate import truncate
|
||||
from .type import type_match
|
||||
|
||||
__all__ = [
|
||||
"json",
|
||||
"dumps",
|
||||
"loads",
|
||||
"validate_with_jsonschema",
|
||||
"SafeJson",
|
||||
"convert_pydantic_to_json",
|
||||
]
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Precompiled regex to remove PostgreSQL-incompatible control characters
|
||||
# Removes \u0000-\u0008, \u000B-\u000C, \u000E-\u001F, \u007F (keeps tab \u0009, newline \u000A, carriage return \u000D)
|
||||
POSTGRES_CONTROL_CHARS = re.compile(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]")
|
||||
|
||||
|
||||
def to_dict(data) -> dict:
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.model_dump()
|
||||
return jsonable_encoder(data)
|
||||
|
||||
|
||||
def dumps(
|
||||
data: Any, *args: Any, indent: int | None = None, option: int = 0, **kwargs: Any
|
||||
) -> str:
|
||||
@@ -118,64 +105,46 @@ def validate_with_jsonschema(
|
||||
return str(e)
|
||||
|
||||
|
||||
def is_list_of_basemodels(value: object) -> TypeGuard[list[BaseModel]]:
|
||||
return isinstance(value, list) and all(
|
||||
isinstance(item, BaseModel) for item in value
|
||||
)
|
||||
def _sanitize_string(value: str) -> str:
|
||||
"""Remove PostgreSQL-incompatible control characters from string."""
|
||||
return POSTGRES_CONTROL_CHARS.sub("", value)
|
||||
|
||||
|
||||
def convert_pydantic_to_json(output_data: Any) -> Any:
|
||||
if isinstance(output_data, BaseModel):
|
||||
return output_data.model_dump()
|
||||
if is_list_of_basemodels(output_data):
|
||||
return [item.model_dump() for item in output_data]
|
||||
return output_data
|
||||
def sanitize_json(data: Any) -> Any:
|
||||
try:
|
||||
# Use two-pass approach for consistent string sanitization:
|
||||
# 1. First convert to basic JSON-serializable types (handles Pydantic models)
|
||||
# 2. Then sanitize strings in the result
|
||||
basic_result = to_dict(data)
|
||||
return to_dict(basic_result, custom_encoder={str: _sanitize_string})
|
||||
except Exception as e:
|
||||
# Log the failure and fall back to string representation
|
||||
logger.error(
|
||||
"SafeJson fallback to string representation due to serialization error: %s (%s). "
|
||||
"Data type: %s, Data preview: %s",
|
||||
type(e).__name__,
|
||||
truncate(str(e), 200),
|
||||
type(data).__name__,
|
||||
truncate(str(data), 100),
|
||||
)
|
||||
|
||||
# Ultimate fallback: convert to string representation and sanitize
|
||||
return _sanitize_string(str(data))
|
||||
|
||||
|
||||
def _sanitize_value(value: Any) -> Any:
|
||||
"""
|
||||
Recursively sanitize values by removing PostgreSQL-incompatible control characters.
|
||||
|
||||
This function walks through data structures and removes control characters from strings.
|
||||
It handles:
|
||||
- Strings: Remove control chars directly from the string
|
||||
- Lists: Recursively sanitize each element
|
||||
- Dicts: Recursively sanitize keys and values
|
||||
- Other types: Return as-is
|
||||
|
||||
Args:
|
||||
value: The value to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized version of the value with control characters removed
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
# Remove control characters directly from the string
|
||||
return POSTGRES_CONTROL_CHARS.sub("", value)
|
||||
elif isinstance(value, dict):
|
||||
# Recursively sanitize dictionary keys and values
|
||||
return {_sanitize_value(k): _sanitize_value(v) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
# Recursively sanitize list elements
|
||||
return [_sanitize_value(item) for item in value]
|
||||
elif isinstance(value, tuple):
|
||||
# Recursively sanitize tuple elements
|
||||
return tuple(_sanitize_value(item) for item in value)
|
||||
else:
|
||||
# For other types (int, float, bool, None, etc.), return as-is
|
||||
return value
|
||||
|
||||
|
||||
def SafeJson(data: Any) -> Json:
|
||||
class SafeJson(Json):
|
||||
"""
|
||||
Safely serialize data and return Prisma's Json type.
|
||||
Sanitizes control characters to prevent PostgreSQL 22P05 errors.
|
||||
|
||||
This function:
|
||||
1. Converts Pydantic models to dicts
|
||||
1. Converts Pydantic models to dicts (recursively using to_dict)
|
||||
2. Recursively removes PostgreSQL-incompatible control characters from strings
|
||||
3. Returns a Prisma Json object safe for database storage
|
||||
|
||||
Uses to_dict (jsonable_encoder) with a custom encoder to handle both Pydantic
|
||||
conversion and control character sanitization in a two-pass approach.
|
||||
|
||||
Args:
|
||||
data: Input data to sanitize and convert to Json
|
||||
|
||||
@@ -187,12 +156,6 @@ def SafeJson(data: Any) -> Json:
|
||||
>>> SafeJson({"path": "C:\\\\temp"}) # backslashes preserved
|
||||
>>> SafeJson({"data": "Text\\\\u0000here"}) # literal backslash-u preserved
|
||||
"""
|
||||
# Convert Pydantic models to dict first
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.model_dump(exclude_none=True)
|
||||
|
||||
# Sanitize the data structure by removing control characters
|
||||
sanitized_data = _sanitize_value(data)
|
||||
|
||||
# Return as Prisma Json type
|
||||
return Json(sanitized_data)
|
||||
def __init__(self, data: Any):
|
||||
super().__init__(sanitize_json(data))
|
||||
|
||||
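A quick illustration of the behaviour after this refactor, assuming SafeJson subclasses prisma.Json as shown and sanitize_json strips the control characters listed earlier:

```python
from pydantic import BaseModel


class Comment(BaseModel):
    body: str


# A NUL byte in the payload would otherwise trigger PostgreSQL error 22P05.
payload = Comment(body="hello\x00world")
safe = SafeJson(payload)  # model -> dict via to_dict, then strings sanitized
# The value handed to Prisma is {"body": "helloworld"}; tabs and newlines survive.
```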
@@ -8,10 +8,7 @@ settings = Settings()
|
||||
def configure_logging():
|
||||
import autogpt_libs.logging.config
|
||||
|
||||
if (
|
||||
settings.config.behave_as == BehaveAs.LOCAL
|
||||
or settings.config.app_env == AppEnvironment.LOCAL
|
||||
):
|
||||
if not is_structured_logging_enabled():
|
||||
autogpt_libs.logging.config.configure_logging(force_cloud_logging=False)
|
||||
else:
|
||||
autogpt_libs.logging.config.configure_logging(force_cloud_logging=True)
|
||||
@@ -20,6 +17,14 @@ def configure_logging():
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def is_structured_logging_enabled() -> bool:
|
||||
"""Check if structured logging (cloud logging) is enabled."""
|
||||
return not (
|
||||
settings.config.behave_as == BehaveAs.LOCAL
|
||||
or settings.config.app_env == AppEnvironment.LOCAL
|
||||
)
|
||||
|
||||
|
||||
class TruncatedLogger:
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -3,15 +3,17 @@ from enum import Enum
|
||||
|
||||
import sentry_sdk
|
||||
from pydantic import SecretStr
|
||||
from sentry_sdk.integrations import DidNotEnable
|
||||
from sentry_sdk.integrations.anthropic import AnthropicIntegration
|
||||
from sentry_sdk.integrations.asyncio import AsyncioIntegration
|
||||
from sentry_sdk.integrations.launchdarkly import LaunchDarklyIntegration
|
||||
from sentry_sdk.integrations.logging import LoggingIntegration
|
||||
|
||||
from backend.util.feature_flag import get_client, is_configured
|
||||
from backend.util import feature_flag
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DiscordChannel(str, Enum):
|
||||
@@ -22,8 +24,11 @@ class DiscordChannel(str, Enum):
|
||||
def sentry_init():
|
||||
sentry_dsn = settings.secrets.sentry_dsn
|
||||
integrations = []
|
||||
if is_configured():
|
||||
integrations.append(LaunchDarklyIntegration(get_client()))
|
||||
if feature_flag.is_configured():
|
||||
try:
|
||||
integrations.append(LaunchDarklyIntegration(feature_flag.get_client()))
|
||||
except DidNotEnable as e:
|
||||
logger.error(f"Error enabling LaunchDarklyIntegration for Sentry: {e}")
|
||||
sentry_sdk.init(
|
||||
dsn=sentry_dsn,
|
||||
traces_sample_rate=1.0,
|
||||
|
||||
@@ -8,18 +8,9 @@ from typing import Optional
|
||||
|
||||
from backend.util.logging import configure_logging
|
||||
from backend.util.metrics import sentry_init
|
||||
from backend.util.settings import set_service_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_SERVICE_NAME = "MainProcess"
|
||||
|
||||
|
||||
def get_service_name():
|
||||
return _SERVICE_NAME
|
||||
|
||||
|
||||
def set_service_name(name: str):
|
||||
global _SERVICE_NAME
|
||||
_SERVICE_NAME = name
|
||||
|
||||
|
||||
class AppProcess(ABC):
|
||||
@@ -28,7 +19,8 @@ class AppProcess(ABC):
|
||||
"""
|
||||
|
||||
process: Optional[Process] = None
|
||||
cleaned_up = False
|
||||
_shutting_down: bool = False
|
||||
_cleaned_up: bool = False
|
||||
|
||||
if "forkserver" in get_all_start_methods():
|
||||
set_start_method("forkserver", force=True)
|
||||
@@ -52,7 +44,6 @@ class AppProcess(ABC):
|
||||
def service_name(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
@abstractmethod
|
||||
def cleanup(self):
|
||||
"""
|
||||
Implement this method on a subclass to do post-execution cleanup,
|
||||
@@ -74,7 +65,8 @@ class AppProcess(ABC):
|
||||
self.run()
|
||||
except BaseException as e:
|
||||
logger.warning(
|
||||
f"[{self.service_name}] Termination request: {type(e).__name__}; {e} executing cleanup."
|
||||
f"[{self.service_name}] 🛑 Terminating because of {type(e).__name__}: {e}", # noqa
|
||||
exc_info=e if not isinstance(e, SystemExit) else None,
|
||||
)
|
||||
# Send error to Sentry before cleanup
|
||||
if not isinstance(e, (KeyboardInterrupt, SystemExit)):
|
||||
@@ -85,8 +77,12 @@ class AppProcess(ABC):
|
||||
except Exception:
|
||||
pass # Silently ignore if Sentry isn't available
|
||||
finally:
|
||||
self.cleanup()
|
||||
logger.info(f"[{self.service_name}] Terminated.")
|
||||
if not self._cleaned_up:
|
||||
self._cleaned_up = True
|
||||
logger.info(f"[{self.service_name}] 🧹 Running cleanup")
|
||||
self.cleanup()
|
||||
logger.info(f"[{self.service_name}] ✅ Cleanup done")
|
||||
logger.info(f"[{self.service_name}] 🛑 Terminated")
|
||||
|
||||
@staticmethod
|
||||
def llprint(message: str):
|
||||
@@ -97,8 +93,8 @@ class AppProcess(ABC):
|
||||
os.write(sys.stdout.fileno(), (message + "\n").encode())
|
||||
|
||||
def _self_terminate(self, signum: int, frame):
|
||||
if not self.cleaned_up:
|
||||
self.cleaned_up = True
|
||||
if not self._shutting_down:
|
||||
self._shutting_down = True
|
||||
sys.exit(0)
|
||||
else:
|
||||
self.llprint(
|
||||
|
||||
@@ -13,7 +13,7 @@ import idna
|
||||
from aiohttp import FormData, abc
|
||||
from tenacity import retry, retry_if_result, wait_exponential_jitter
|
||||
|
||||
from backend.util.json import json
|
||||
from backend.util.json import loads
|
||||
|
||||
# Retry status codes for which we will automatically retry the request
|
||||
THROTTLE_RETRY_STATUS_CODES: set[int] = {429, 500, 502, 503, 504, 408}
|
||||
@@ -175,10 +175,15 @@ async def validate_url(
|
||||
f"for hostname {ascii_hostname} is not allowed."
|
||||
)
|
||||
|
||||
# Reconstruct the netloc with IDNA-encoded hostname and preserve port
|
||||
netloc = ascii_hostname
|
||||
if parsed.port:
|
||||
netloc = f"{ascii_hostname}:{parsed.port}"
|
||||
|
||||
return (
|
||||
URL(
|
||||
parsed.scheme,
|
||||
ascii_hostname,
|
||||
netloc,
|
||||
quote(parsed.path, safe="/%:@"),
|
||||
parsed.params,
|
||||
parsed.query,
|
||||
@@ -259,7 +264,7 @@ class Response:
|
||||
"""
|
||||
Parse the body as JSON and return the resulting Python object.
|
||||
"""
|
||||
return json.loads(
|
||||
return loads(
|
||||
self.content.decode(encoding or "utf-8", errors="replace"), **kwargs
|
||||
)
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from tenacity import (
|
||||
wait_exponential_jitter,
|
||||
)
|
||||
|
||||
from backend.util.process import get_service_name
|
||||
from backend.util.settings import get_service_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -4,9 +4,12 @@ import concurrent.futures
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import update_wrapper
|
||||
from typing import (
|
||||
Any,
|
||||
@@ -31,9 +34,9 @@ import backend.util.exceptions as exceptions
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
from backend.util.json import to_dict
|
||||
from backend.util.metrics import sentry_init
|
||||
from backend.util.process import AppProcess, get_service_name
|
||||
from backend.util.process import AppProcess
|
||||
from backend.util.retry import conn_retry, create_retry_decorator
|
||||
from backend.util.settings import Config
|
||||
from backend.util.settings import Config, get_service_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T")
|
||||
@@ -111,14 +114,44 @@ class BaseAppService(AppProcess, ABC):
|
||||
return target_host
|
||||
|
||||
def run_service(self) -> None:
|
||||
while True:
|
||||
time.sleep(10)
|
||||
# HACK: run the main event loop outside the main thread to disable Uvicorn's
|
||||
# internal signal handlers, since there is no config option for this :(
|
||||
shared_asyncio_thread = threading.Thread(
|
||||
target=self._run_shared_event_loop,
|
||||
daemon=True,
|
||||
name=f"{self.service_name}-shared-event-loop",
|
||||
)
|
||||
shared_asyncio_thread.start()
|
||||
shared_asyncio_thread.join()
|
||||
|
||||
def _run_shared_event_loop(self) -> None:
|
||||
try:
|
||||
self.shared_event_loop.run_forever()
|
||||
finally:
|
||||
logger.info(f"[{self.service_name}] 🛑 Shared event loop stopped")
|
||||
self.shared_event_loop.close() # ensure held resources are released
|
||||
|
||||
def run_and_wait(self, coro: Coroutine[Any, Any, T]) -> T:
|
||||
return asyncio.run_coroutine_threadsafe(coro, self.shared_event_loop).result()
|
||||
|
||||
def run(self):
|
||||
self.shared_event_loop = asyncio.get_event_loop()
|
||||
self.shared_event_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.shared_event_loop)
|
||||
|
||||
def cleanup(self):
|
||||
"""
|
||||
**💡 Overriding `AppService.lifespan` may be a more convenient option.**
|
||||
|
||||
Implement this method on a subclass to do post-execution cleanup,
|
||||
e.g. disconnecting from a database or terminating child processes.
|
||||
|
||||
**Note:** if you override this method in a subclass, it must call
|
||||
`super().cleanup()` *at the end*!
|
||||
"""
|
||||
# Stop the shared event loop to allow resource clean-up
|
||||
self.shared_event_loop.call_soon_threadsafe(self.shared_event_loop.stop)
|
||||
|
||||
super().cleanup()
|
||||
|
||||
|
||||
class RemoteCallError(BaseModel):
|
||||
@@ -179,6 +212,7 @@ EXCEPTION_MAPPING = {
|
||||
|
||||
class AppService(BaseAppService, ABC):
|
||||
fastapi_app: FastAPI
|
||||
http_server: uvicorn.Server | None = None
|
||||
log_level: str = "info"
|
||||
|
||||
def set_log_level(self, log_level: str):
|
||||
@@ -190,11 +224,10 @@ class AppService(BaseAppService, ABC):
|
||||
def _handle_internal_http_error(status_code: int = 500, log_error: bool = True):
|
||||
def handler(request: Request, exc: Exception):
|
||||
if log_error:
|
||||
if status_code == 500:
|
||||
log = logger.exception
|
||||
else:
|
||||
log = logger.error
|
||||
log(f"{request.method} {request.url.path} failed: {exc}")
|
||||
logger.error(
|
||||
f"{request.method} {request.url.path} failed: {exc}",
|
||||
exc_info=exc if status_code == 500 else None,
|
||||
)
|
||||
return responses.JSONResponse(
|
||||
status_code=status_code,
|
||||
content=RemoteCallError(
|
||||
@@ -256,13 +289,13 @@ class AppService(BaseAppService, ABC):
|
||||
|
||||
return sync_endpoint
|
||||
|
||||
@conn_retry("FastAPI server", "Starting FastAPI server")
|
||||
@conn_retry("FastAPI server", "Running FastAPI server")
|
||||
def __start_fastapi(self):
|
||||
logger.info(
|
||||
f"[{self.service_name}] Starting RPC server at http://{api_host}:{self.get_port()}"
|
||||
)
|
||||
|
||||
server = uvicorn.Server(
|
||||
self.http_server = uvicorn.Server(
|
||||
uvicorn.Config(
|
||||
self.fastapi_app,
|
||||
host=api_host,
|
||||
@@ -271,18 +304,76 @@ class AppService(BaseAppService, ABC):
|
||||
log_level=self.log_level,
|
||||
)
|
||||
)
|
||||
self.shared_event_loop.run_until_complete(server.serve())
|
||||
self.run_and_wait(self.http_server.serve())
|
||||
|
||||
# Perform clean-up when the server exits
|
||||
if not self._cleaned_up:
|
||||
self._cleaned_up = True
|
||||
logger.info(f"[{self.service_name}] 🧹 Running cleanup")
|
||||
self.cleanup()
|
||||
logger.info(f"[{self.service_name}] ✅ Cleanup done")
|
||||
|
||||
def _self_terminate(self, signum: int, frame):
|
||||
"""Pass SIGTERM to Uvicorn so it can shut down gracefully"""
|
||||
signame = signal.Signals(signum).name
|
||||
if not self._shutting_down:
|
||||
self._shutting_down = True
|
||||
if self.http_server:
|
||||
logger.info(
|
||||
f"[{self.service_name}] 🛑 Received {signame} ({signum}) - "
|
||||
"Entering RPC server graceful shutdown"
|
||||
)
|
||||
self.http_server.handle_exit(signum, frame) # stop accepting requests
|
||||
|
||||
# NOTE: Actually stopping the process is triggered by:
|
||||
# 1. The call to self.cleanup() at the end of __start_fastapi() 👆🏼
|
||||
# 2. BaseAppService.cleanup() stopping the shared event loop
|
||||
else:
|
||||
logger.warning(
|
||||
f"[{self.service_name}] {signame} received before HTTP server init."
|
||||
" Terminating..."
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
else:
|
||||
# Expedite shutdown on second SIGTERM
|
||||
logger.info(
|
||||
f"[{self.service_name}] 🛑🛑 Received {signame} ({signum}), "
|
||||
"but shutdown is already underway. Terminating..."
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(self, app: FastAPI):
|
||||
"""
|
||||
The FastAPI/Uvicorn server's lifespan manager, used for setup and shutdown.
|
||||
|
||||
You can extend and use this in a subclass like:
|
||||
```
|
||||
@asynccontextmanager
|
||||
async def lifespan(self, app: FastAPI):
|
||||
async with super().lifespan(app):
|
||||
await db.connect()
|
||||
yield
|
||||
await db.disconnect()
|
||||
```
|
||||
"""
|
||||
# Startup - this runs before Uvicorn starts accepting connections
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown - this runs when FastAPI/Uvicorn shuts down
|
||||
logger.info(f"[{self.service_name}] ✅ FastAPI has finished")
|
||||
|
||||
async def health_check(self) -> str:
|
||||
"""
|
||||
A method to check the health of the process.
|
||||
"""
|
||||
"""A method to check the health of the process."""
|
||||
return "OK"
|
||||
|
||||
def run(self):
|
||||
sentry_init()
|
||||
super().run()
|
||||
self.fastapi_app = FastAPI()
|
||||
|
||||
self.fastapi_app = FastAPI(lifespan=self.lifespan)
|
||||
|
||||
# Add Prometheus instrumentation to all services
|
||||
try:
|
||||
@@ -325,7 +416,11 @@ class AppService(BaseAppService, ABC):
|
||||
)
|
||||
|
||||
# Start the FastAPI server in a separate thread.
|
||||
api_thread = threading.Thread(target=self.__start_fastapi, daemon=True)
|
||||
api_thread = threading.Thread(
|
||||
target=self.__start_fastapi,
|
||||
daemon=True,
|
||||
name=f"{self.service_name}-http-server",
|
||||
)
|
||||
api_thread.start()
|
||||
|
||||
# Run the main service loop (blocking).
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import time
|
||||
from functools import cached_property
|
||||
from unittest.mock import Mock
|
||||
@@ -18,20 +20,11 @@ from backend.util.service import (
|
||||
TEST_SERVICE_PORT = 8765
|
||||
|
||||
|
||||
def wait_for_service_ready(service_client_type, timeout_seconds=30):
|
||||
"""Helper method to wait for a service to be ready using health check with retry."""
|
||||
client = get_service_client(service_client_type, request_retry=True)
|
||||
client.health_check() # This will retry until service is ready
|
||||
|
||||
|
||||
class ServiceTest(AppService):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fail_count = 0
|
||||
|
||||
def cleanup(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_port(cls) -> int:
|
||||
return TEST_SERVICE_PORT
|
||||
@@ -41,10 +34,17 @@ class ServiceTest(AppService):
|
||||
result = super().__enter__()
|
||||
|
||||
# Wait for the service to be ready
|
||||
wait_for_service_ready(ServiceTestClient)
|
||||
self.wait_until_ready()
|
||||
|
||||
return result
|
||||
|
||||
def wait_until_ready(self, timeout_seconds: int = 5):
|
||||
"""Helper method to wait for a service to be ready using health check with retry."""
|
||||
client = get_service_client(
|
||||
ServiceTestClient, call_timeout=timeout_seconds, request_retry=True
|
||||
)
|
||||
client.health_check() # This will retry until service is ready
|
||||
|
||||
@expose
|
||||
def add(self, a: int, b: int) -> int:
|
||||
return a + b
|
||||
@@ -490,3 +490,167 @@ class TestHTTPErrorRetryBehavior:
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == status_code
|
||||
|
||||
|
||||
class TestGracefulShutdownService(AppService):
|
||||
"""Test service with slow endpoints for testing graceful shutdown"""
|
||||
|
||||
@classmethod
|
||||
def get_port(cls) -> int:
|
||||
return 18999 # Use a specific test port
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.request_log = []
|
||||
self.cleanup_called = False
|
||||
self.cleanup_completed = False
|
||||
|
||||
@expose
|
||||
async def slow_endpoint(self, duration: int = 5) -> dict:
|
||||
"""Endpoint that takes time to complete"""
|
||||
start_time = time.time()
|
||||
self.request_log.append(f"slow_endpoint started at {start_time}")
|
||||
|
||||
await asyncio.sleep(duration)
|
||||
|
||||
end_time = time.time()
|
||||
result = {
|
||||
"message": "completed",
|
||||
"duration": end_time - start_time,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
}
|
||||
self.request_log.append(f"slow_endpoint completed at {end_time}")
|
||||
return result
|
||||
|
||||
@expose
|
||||
def fast_endpoint(self) -> dict:
|
||||
"""Fast endpoint for testing rejection during shutdown"""
|
||||
timestamp = time.time()
|
||||
self.request_log.append(f"fast_endpoint called at {timestamp}")
|
||||
return {"message": "fast", "timestamp": timestamp}
|
||||
|
||||
def cleanup(self):
|
||||
"""Override cleanup to track when it's called"""
|
||||
self.cleanup_called = True
|
||||
self.request_log.append(f"cleanup started at {time.time()}")
|
||||
|
||||
# Call parent cleanup
|
||||
super().cleanup()
|
||||
|
||||
self.cleanup_completed = True
|
||||
self.request_log.append(f"cleanup completed at {time.time()}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def test_service():
|
||||
"""Run the test service in a separate process"""
|
||||
|
||||
service = TestGracefulShutdownService()
|
||||
service.start(background=True)
|
||||
|
||||
base_url = f"http://localhost:{service.get_port()}"
|
||||
|
||||
await wait_until_service_ready(base_url)
|
||||
yield service, base_url
|
||||
|
||||
service.stop()
|
||||
|
||||
|
||||
async def wait_until_service_ready(base_url: str, timeout: float = 10):
|
||||
start_time = time.time()
|
||||
while time.time() - start_time <= timeout:
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
with contextlib.suppress(httpx.ConnectError):
|
||||
response = await client.get(f"{base_url}/health_check", timeout=5)
|
||||
|
||||
if response.status_code == 200 and response.json() == "OK":
|
||||
return
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
raise RuntimeError(f"Service at {base_url} not available after {timeout} seconds")
|
||||
|
||||
|
||||
async def send_slow_request(base_url: str) -> dict:
|
||||
"""Send a slow request and return the result"""
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.post(f"{base_url}/slow_endpoint", json={"duration": 5})
|
||||
assert response.status_code == 200
|
||||
return response.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graceful_shutdown(test_service):
|
||||
"""Test that AppService handles graceful shutdown correctly"""
|
||||
service, test_service_url = test_service
|
||||
|
||||
# Start a slow request that should complete even after shutdown
|
||||
slow_task = asyncio.create_task(send_slow_request(test_service_url))
|
||||
|
||||
# Give the slow request time to start
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Send SIGTERM to the service process
|
||||
shutdown_start_time = time.time()
|
||||
service.process.terminate() # This sends SIGTERM
|
||||
|
||||
# Wait a moment for shutdown to start
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Try to send a new request - should be rejected or connection refused
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
response = await client.post(f"{test_service_url}/fast_endpoint", json={})
|
||||
# Should get 503 Service Unavailable during shutdown
|
||||
assert response.status_code == 503
|
||||
assert "shutting down" in response.json()["detail"].lower()
|
||||
except httpx.ConnectError:
|
||||
# Connection refused is also acceptable - server stopped accepting
|
||||
pass
|
||||
|
||||
# The slow request should still complete successfully
|
||||
slow_result = await slow_task
|
||||
assert slow_result["message"] == "completed"
|
||||
assert 4.9 < slow_result["duration"] < 5.5 # Should have taken ~5 seconds
|
||||
|
||||
# Wait for the service to fully shut down
|
||||
service.process.join(timeout=15)
|
||||
shutdown_end_time = time.time()
|
||||
|
||||
# Verify the service actually terminated
|
||||
assert not service.process.is_alive()
|
||||
|
||||
# Verify shutdown took a reasonable time: the ~5s slow request minus its ~1s head start, plus cleanup
shutdown_duration = shutdown_end_time - shutdown_start_time
assert 4 <= shutdown_duration <= 6 # ~5s request minus 1s head start, plus buffer
|
||||
|
||||
print(f"Shutdown took {shutdown_duration:.2f} seconds")
|
||||
print(f"Slow request completed in: {slow_result['duration']:.2f} seconds")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_during_shutdown(test_service):
|
||||
"""Test that health checks behave correctly during shutdown"""
|
||||
service, test_service_url = test_service
|
||||
|
||||
# Health check should pass initially
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
response = await client.get(f"{test_service_url}/health_check")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Send SIGTERM
|
||||
service.process.terminate()
|
||||
|
||||
# Wait for shutdown to begin
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Health check should now fail or connection should be refused
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
response = await client.get(f"{test_service_url}/health_check")
|
||||
# Could either get 503, 500 (unhealthy), or connection error
|
||||
assert response.status_code in [500, 503]
|
||||
except (httpx.ConnectError, httpx.ConnectTimeout):
|
||||
# Connection refused/timeout is also acceptable
|
||||
pass
|
||||
|
||||
@@ -15,6 +15,17 @@ from backend.util.data import get_data_path
|
||||
|
||||
T = TypeVar("T", bound=BaseSettings)
|
||||
|
||||
_SERVICE_NAME = "MainProcess"
|
||||
|
||||
|
||||
def get_service_name():
|
||||
return _SERVICE_NAME
|
||||
|
||||
|
||||
def set_service_name(name: str):
|
||||
global _SERVICE_NAME
|
||||
_SERVICE_NAME = name
|
||||
|
||||
|
||||
class AppEnvironment(str, Enum):
|
||||
LOCAL = "local"
|
||||
@@ -254,6 +265,7 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default="localhost",
|
||||
description="The host for the RabbitMQ server",
|
||||
)
|
||||
|
||||
rabbitmq_port: int = Field(
|
||||
default=5672,
|
||||
description="The port for the RabbitMQ server",
|
||||
@@ -264,6 +276,21 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="The vhost for the RabbitMQ server",
|
||||
)
|
||||
|
||||
redis_host: str = Field(
|
||||
default="localhost",
|
||||
description="The host for the Redis server",
|
||||
)
|
||||
|
||||
redis_port: int = Field(
|
||||
default=6379,
|
||||
description="The port for the Redis server",
|
||||
)
|
||||
|
||||
redis_password: str = Field(
|
||||
default="",
|
||||
description="The password for the Redis server (empty string if no password)",
|
||||
)
|
||||
|
||||
postmark_sender_email: str = Field(
|
||||
default="invalid@invalid.com",
|
||||
description="The email address to use for sending emails",
|
||||
|
||||
@@ -411,3 +411,346 @@ class TestSafeJson:
|
||||
assert "C:\\temp\\file" in str(file_path_with_null)
|
||||
assert ".txt" in str(file_path_with_null)
|
||||
assert "\x00" not in str(file_path_with_null) # Null removed from path
|
||||
|
||||
def test_invalid_escape_error_prevention(self):
|
||||
"""Test that SafeJson prevents 'Invalid \\escape' errors that occurred in upsert_execution_output."""
|
||||
# This reproduces the exact scenario that was causing the error:
|
||||
# POST /upsert_execution_output failed: Invalid \escape: line 1 column 36404 (char 36403)
|
||||
|
||||
# Create data with various problematic escape sequences that could cause JSON parsing errors
|
||||
problematic_output_data = {
|
||||
"web_content": "Article text\x00with null\x01and control\x08chars\x0C\x1F\x7F",
|
||||
"file_path": "C:\\Users\\test\\file\x00.txt",
|
||||
"json_like_string": '{"text": "data\x00\x08\x1F"}',
|
||||
"escaped_sequences": "Text with \\u0000 and \\u0008 sequences",
|
||||
"mixed_content": "Normal text\tproperly\nformatted\rwith\x00invalid\x08chars\x1Fmixed",
|
||||
"large_text": "A" * 35000
|
||||
+ "\x00\x08\x1F"
|
||||
+ "B" * 5000, # Large text like in the error
|
||||
}
|
||||
|
||||
# This should not raise any JSON parsing errors
|
||||
result = SafeJson(problematic_output_data)
|
||||
assert isinstance(result, Json)
|
||||
|
||||
# Verify the result is a valid Json object that can be safely stored in PostgreSQL
|
||||
result_data = cast(dict[str, Any], result.data)
|
||||
assert isinstance(result_data, dict)
|
||||
|
||||
# Verify problematic characters are removed but safe content preserved
|
||||
web_content = result_data.get("web_content", "")
|
||||
file_path = result_data.get("file_path", "")
|
||||
large_text = result_data.get("large_text", "")
|
||||
|
||||
# Check that control characters are removed
|
||||
assert "\x00" not in str(web_content)
|
||||
assert "\x01" not in str(web_content)
|
||||
assert "\x08" not in str(web_content)
|
||||
assert "\x0C" not in str(web_content)
|
||||
assert "\x1F" not in str(web_content)
|
||||
assert "\x7F" not in str(web_content)
|
||||
|
||||
# Check that legitimate content is preserved
|
||||
assert "Article text" in str(web_content)
|
||||
assert "with null" in str(web_content)
|
||||
assert "and control" in str(web_content)
|
||||
assert "chars" in str(web_content)
|
||||
|
||||
# Check file path handling
|
||||
assert "C:\\Users\\test\\file" in str(file_path)
|
||||
assert ".txt" in str(file_path)
|
||||
assert "\x00" not in str(file_path)
|
||||
|
||||
# Check large text handling (the scenario from the error at char 36403)
|
||||
assert len(str(large_text)) > 35000 # Content preserved
|
||||
assert "A" * 1000 in str(large_text) # A's preserved
|
||||
assert "B" * 1000 in str(large_text) # B's preserved
|
||||
assert "\x00" not in str(large_text) # Control chars removed
|
||||
assert "\x08" not in str(large_text)
|
||||
assert "\x1F" not in str(large_text)
|
||||
|
||||
# Most importantly: ensure the result can be JSON-serialized without errors
|
||||
# This would have failed with the old approach
|
||||
import json
|
||||
|
||||
json_string = json.dumps(result.data) # Should not raise "Invalid \escape"
|
||||
assert len(json_string) > 0
|
||||
|
||||
# And can be parsed back
|
||||
parsed_back = json.loads(json_string)
|
||||
assert isinstance(parsed_back, dict)
|
||||
|
||||
def test_dict_containing_pydantic_models(self):
|
||||
"""Test that dicts containing Pydantic models are properly serialized."""
|
||||
# This reproduces the bug from PR #11187 where credential_inputs failed
|
||||
model1 = SamplePydanticModel(name="Alice", age=30)
|
||||
model2 = SamplePydanticModel(name="Bob", age=25)
|
||||
|
||||
data = {
|
||||
"user1": model1,
|
||||
"user2": model2,
|
||||
"regular_data": "test",
|
||||
}
|
||||
|
||||
result = SafeJson(data)
|
||||
assert isinstance(result, Json)
|
||||
|
||||
# Verify it can be JSON serialized (this was the bug)
|
||||
import json
|
||||
|
||||
json_string = json.dumps(result.data)
|
||||
assert "Alice" in json_string
|
||||
assert "Bob" in json_string
|
||||
|
||||
def test_nested_pydantic_in_dict(self):
|
||||
"""Test deeply nested Pydantic models in dicts."""
|
||||
inner_model = SamplePydanticModel(name="Inner", age=20)
|
||||
middle_model = SamplePydanticModel(
|
||||
name="Middle", age=30, metadata={"inner": inner_model}
|
||||
)
|
||||
|
||||
data = {
|
||||
"level1": {
|
||||
"level2": {
|
||||
"model": middle_model,
|
||||
"other": "data",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result = SafeJson(data)
|
||||
assert isinstance(result, Json)
|
||||
|
||||
import json
|
||||
|
||||
json_string = json.dumps(result.data)
|
||||
assert "Middle" in json_string
|
||||
assert "Inner" in json_string
|
||||
|
||||
def test_list_containing_pydantic_models_in_dict(self):
|
||||
"""Test list of Pydantic models inside a dict."""
|
||||
models = [SamplePydanticModel(name=f"User{i}", age=20 + i) for i in range(5)]
|
||||
|
||||
data = {
|
||||
"users": models,
|
||||
"count": len(models),
|
||||
}
|
||||
|
||||
result = SafeJson(data)
|
||||
assert isinstance(result, Json)
|
||||
|
||||
import json
|
||||
|
||||
json_string = json.dumps(result.data)
|
||||
assert "User0" in json_string
|
||||
assert "User4" in json_string
|
||||
|
||||
def test_credentials_meta_input_scenario(self):
|
||||
"""Test the exact scenario from create_graph_execution that was failing."""
|
||||
|
||||
# Simulate CredentialsMetaInput structure
|
||||
class MockCredentialsMetaInput(BaseModel):
|
||||
id: str
|
||||
title: Optional[str] = None
|
||||
provider: str
|
||||
type: str
|
||||
|
||||
cred_input = MockCredentialsMetaInput(
|
||||
id="test-123", title="Test Credentials", provider="github", type="oauth2"
|
||||
)
|
||||
|
||||
# This is how credential_inputs is structured in create_graph_execution
|
||||
credential_inputs = {"github_creds": cred_input}
|
||||
|
||||
# This should work without TypeError
|
||||
result = SafeJson(credential_inputs)
|
||||
assert isinstance(result, Json)
|
||||
|
||||
# Verify it can be JSON serialized
|
||||
import json
|
||||
|
||||
json_string = json.dumps(result.data)
|
||||
assert "test-123" in json_string
|
||||
assert "github" in json_string
|
||||
assert "oauth2" in json_string
|
||||
|
||||
def test_mixed_pydantic_and_primitives(self):
|
||||
"""Test complex mix of Pydantic models and primitive types."""
|
||||
model = SamplePydanticModel(name="Test", age=25)
|
||||
|
||||
data = {
|
||||
"models": [model, {"plain": "dict"}, "string", 123],
|
||||
"nested": {
|
||||
"model": model,
|
||||
"list": [1, 2, model, 4],
|
||||
"plain": "text",
|
||||
},
|
||||
"plain_list": [1, 2, 3],
|
||||
}
|
||||
|
||||
result = SafeJson(data)
|
||||
assert isinstance(result, Json)
|
||||
|
||||
import json
|
||||
|
||||
json_string = json.dumps(result.data)
|
||||
assert "Test" in json_string
|
||||
assert "plain" in json_string
|
||||
|
||||
def test_pydantic_model_with_control_chars_in_dict(self):
|
||||
"""Test Pydantic model with control chars when nested in dict."""
|
||||
model = SamplePydanticModel(
|
||||
name="Test\x00User", # Has null byte
|
||||
age=30,
|
||||
metadata={"info": "data\x08with\x0Ccontrols"},
|
||||
)
|
||||
|
||||
data = {"credential": model}
|
||||
|
||||
result = SafeJson(data)
|
||||
assert isinstance(result, Json)
|
||||
|
||||
# Verify control characters are removed
|
||||
import json
|
||||
|
||||
json_string = json.dumps(result.data)
|
||||
assert "\x00" not in json_string
|
||||
assert "\x08" not in json_string
|
||||
assert "\x0C" not in json_string
|
||||
assert "TestUser" in json_string # Name preserved minus null byte
|
||||
|
||||
def test_deeply_nested_pydantic_models_control_char_sanitization(self):
|
||||
"""Test that control characters are sanitized in deeply nested Pydantic models."""
|
||||
|
||||
# Create nested Pydantic models with control characters at different levels
|
||||
class InnerModel(BaseModel):
|
||||
deep_string: str
|
||||
value: int = 42
|
||||
metadata: dict = {}
|
||||
|
||||
class MiddleModel(BaseModel):
|
||||
middle_string: str
|
||||
inner: InnerModel
|
||||
data: str
|
||||
|
||||
class OuterModel(BaseModel):
|
||||
outer_string: str
|
||||
middle: MiddleModel
|
||||
|
||||
# Create test data with control characters at every nesting level
|
||||
inner = InnerModel(
|
||||
deep_string="Deepest\x00Level\x08Control\x0CChars", # Multiple control chars at deepest level
|
||||
metadata={
|
||||
"nested_key": "Nested\x1FValue\x7FDelete"
|
||||
}, # Control chars in nested dict
|
||||
)
|
||||
|
||||
middle = MiddleModel(
|
||||
middle_string="Middle\x01StartOfHeading\x1FUnitSeparator",
|
||||
inner=inner,
|
||||
data="Some\x0BVerticalTab\x0EShiftOut",
|
||||
)
|
||||
|
||||
outer = OuterModel(outer_string="Outer\x00Null\x07Bell", middle=middle)
|
||||
|
||||
# Wrap in a dict with additional control characters
|
||||
data = {
|
||||
"top_level": "Top\x00Level\x08Backspace",
|
||||
"nested_model": outer,
|
||||
"list_with_strings": [
|
||||
"List\x00Item1",
|
||||
"List\x0CItem2\x1F",
|
||||
{"dict_in_list": "Dict\x08Value"},
|
||||
],
|
||||
}
|
||||
|
||||
# Process with SafeJson
|
||||
result = SafeJson(data)
|
||||
assert isinstance(result, Json)
|
||||
|
||||
# Verify all control characters are removed at every level
|
||||
import json
|
||||
|
||||
json_string = json.dumps(result.data)
|
||||
|
||||
# Check that NO control characters remain anywhere
|
||||
control_chars = [
|
||||
"\x00",
|
||||
"\x01",
|
||||
"\x02",
|
||||
"\x03",
|
||||
"\x04",
|
||||
"\x05",
|
||||
"\x06",
|
||||
"\x07",
|
||||
"\x08",
|
||||
"\x0B",
|
||||
"\x0C",
|
||||
"\x0E",
|
||||
"\x0F",
|
||||
"\x10",
|
||||
"\x11",
|
||||
"\x12",
|
||||
"\x13",
|
||||
"\x14",
|
||||
"\x15",
|
||||
"\x16",
|
||||
"\x17",
|
||||
"\x18",
|
||||
"\x19",
|
||||
"\x1A",
|
||||
"\x1B",
|
||||
"\x1C",
|
||||
"\x1D",
|
||||
"\x1E",
|
||||
"\x1F",
|
||||
"\x7F",
|
||||
]
|
||||
|
||||
for char in control_chars:
|
||||
assert (
|
||||
char not in json_string
|
||||
), f"Control character {repr(char)} found in result"
|
||||
|
||||
# Verify specific sanitized content is present (control chars removed but text preserved)
|
||||
result_data = cast(dict[str, Any], result.data)
|
||||
|
||||
# Top level
|
||||
assert "TopLevelBackspace" in json_string
|
||||
|
||||
# Outer model level
|
||||
assert "OuterNullBell" in json_string
|
||||
|
||||
# Middle model level
|
||||
assert "MiddleStartOfHeadingUnitSeparator" in json_string
|
||||
assert "SomeVerticalTabShiftOut" in json_string
|
||||
|
||||
# Inner model level (deepest nesting)
|
||||
assert "DeepestLevelControlChars" in json_string
|
||||
|
||||
# Nested dict in model
|
||||
assert "NestedValueDelete" in json_string
|
||||
|
||||
# List items
|
||||
assert "ListItem1" in json_string
|
||||
assert "ListItem2" in json_string
|
||||
assert "DictValue" in json_string
|
||||
|
||||
# Verify structure is preserved (not just converted to string)
|
||||
assert isinstance(result_data, dict)
|
||||
assert isinstance(result_data["nested_model"], dict)
|
||||
assert isinstance(result_data["nested_model"]["middle"], dict)
|
||||
assert isinstance(result_data["nested_model"]["middle"]["inner"], dict)
|
||||
assert isinstance(result_data["list_with_strings"], list)
|
||||
|
||||
# Verify specific deep values are accessible and sanitized
|
||||
nested_model = cast(dict[str, Any], result_data["nested_model"])
|
||||
middle = cast(dict[str, Any], nested_model["middle"])
|
||||
inner = cast(dict[str, Any], middle["inner"])
|
||||
|
||||
deep_string = inner["deep_string"]
|
||||
assert deep_string == "DeepestLevelControlChars"
|
||||
|
||||
metadata = cast(dict[str, Any], inner["metadata"])
|
||||
nested_metadata = metadata["nested_key"]
|
||||
assert nested_metadata == "NestedValueDelete"
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
import bleach
|
||||
from bleach.css_sanitizer import CSSSanitizer
|
||||
from jinja2 import BaseLoader
|
||||
from jinja2.exceptions import TemplateError
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
from markupsafe import Markup
|
||||
|
||||
@@ -101,8 +102,11 @@ class TextFormatter:
|
||||
|
||||
def format_string(self, template_str: str, values=None, **kwargs) -> str:
|
||||
"""Regular template rendering with escaping"""
|
||||
template = self.env.from_string(template_str)
|
||||
return template.render(values or {}, **kwargs)
|
||||
try:
|
||||
template = self.env.from_string(template_str)
|
||||
return template.render(values or {}, **kwargs)
|
||||
except TemplateError as e:
|
||||
raise ValueError(e) from e
|
||||
|
||||
def format_email(
|
||||
self,
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
-- Create UserBalance table for atomic credit operations
|
||||
-- This replaces the need for User.balance column and provides better separation of concerns
|
||||
-- UserBalance records are automatically created by the application when users interact with the credit system
|
||||
|
||||
-- CreateTable (only if it doesn't exist)
|
||||
CREATE TABLE IF NOT EXISTS "UserBalance" (
|
||||
"userId" TEXT NOT NULL,
|
||||
"balance" INTEGER NOT NULL DEFAULT 0,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
|
||||
CONSTRAINT "UserBalance_pkey" PRIMARY KEY ("userId"),
|
||||
CONSTRAINT "UserBalance_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE
|
||||
);
|
||||
|
||||
-- CreateIndex (only if it doesn't exist)
|
||||
CREATE INDEX IF NOT EXISTS "UserBalance_userId_idx" ON "UserBalance"("userId");
|
||||
@@ -0,0 +1,100 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "StoreListingVersion" ADD COLUMN "search" tsvector DEFAULT ''::tsvector;
|
||||
|
||||
-- Add trigger to update the search column with the tsvector of the agent
|
||||
-- Function to be invoked by trigger
|
||||
|
||||
-- Drop the trigger first
|
||||
DROP TRIGGER IF EXISTS "update_tsvector" ON "StoreListingVersion";
|
||||
|
||||
-- Drop the function completely
|
||||
DROP FUNCTION IF EXISTS update_tsvector_column();
|
||||
|
||||
-- Now recreate it fresh
|
||||
CREATE OR REPLACE FUNCTION update_tsvector_column() RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
NEW.search := to_tsvector('english',
|
||||
COALESCE(NEW.name, '') || ' ' ||
|
||||
COALESCE(NEW.description, '') || ' ' ||
|
||||
COALESCE(NEW."subHeading", '')
|
||||
);
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql SECURITY DEFINER SET search_path = platform, pg_temp;
|
||||
|
||||
-- Recreate the trigger
|
||||
CREATE TRIGGER "update_tsvector"
|
||||
BEFORE INSERT OR UPDATE ON "StoreListingVersion"
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_tsvector_column();
|
||||
|
||||
UPDATE "StoreListingVersion"
|
||||
SET search = to_tsvector('english',
|
||||
COALESCE(name, '') || ' ' ||
|
||||
COALESCE(description, '') || ' ' ||
|
||||
COALESCE("subHeading", '')
|
||||
)
|
||||
WHERE search IS NULL;
|
||||
|
||||
-- Drop and recreate the StoreAgent view with isAvailable field
|
||||
DROP VIEW IF EXISTS "StoreAgent";
|
||||
|
||||
CREATE OR REPLACE VIEW "StoreAgent" AS
|
||||
WITH latest_versions AS (
|
||||
SELECT
|
||||
"storeListingId",
|
||||
MAX(version) AS max_version
|
||||
FROM "StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
GROUP BY "storeListingId"
|
||||
),
|
||||
agent_versions AS (
|
||||
SELECT
|
||||
"storeListingId",
|
||||
array_agg(DISTINCT version::text ORDER BY version::text) AS versions
|
||||
FROM "StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
GROUP BY "storeListingId"
|
||||
)
|
||||
SELECT
|
||||
sl.id AS listing_id,
|
||||
slv.id AS "storeListingVersionId",
|
||||
slv."createdAt" AS updated_at,
|
||||
sl.slug,
|
||||
COALESCE(slv.name, '') AS agent_name,
|
||||
slv."videoUrl" AS agent_video,
|
||||
COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
|
||||
slv."isFeatured" AS featured,
|
||||
p.username AS creator_username, -- Allow NULL for malformed sub-agents
|
||||
p."avatarUrl" AS creator_avatar, -- Allow NULL for malformed sub-agents
|
||||
slv."subHeading" AS sub_heading,
|
||||
slv.description,
|
||||
slv.categories,
|
||||
slv.search,
|
||||
COALESCE(ar.run_count, 0::bigint) AS runs,
|
||||
COALESCE(rs.avg_rating, 0.0)::double precision AS rating,
|
||||
COALESCE(av.versions, ARRAY[slv.version::text]) AS versions,
|
||||
COALESCE(sl."useForOnboarding", false) AS "useForOnboarding",
|
||||
slv."isAvailable" AS is_available -- Add isAvailable field to filter sub-agents
|
||||
FROM "StoreListing" sl
|
||||
JOIN latest_versions lv
|
||||
ON sl.id = lv."storeListingId"
|
||||
JOIN "StoreListingVersion" slv
|
||||
ON slv."storeListingId" = lv."storeListingId"
|
||||
AND slv.version = lv.max_version
|
||||
AND slv."submissionStatus" = 'APPROVED'
|
||||
JOIN "AgentGraph" a
|
||||
ON slv."agentGraphId" = a.id
|
||||
AND slv."agentGraphVersion" = a.version
|
||||
LEFT JOIN "Profile" p
|
||||
ON sl."owningUserId" = p."userId"
|
||||
LEFT JOIN "mv_review_stats" rs
|
||||
ON sl.id = rs."storeListingId"
|
||||
LEFT JOIN "mv_agent_run_counts" ar
|
||||
ON a.id = ar."agentGraphId"
|
||||
LEFT JOIN agent_versions av
|
||||
ON sl.id = av."storeListingId"
|
||||
WHERE sl."isDeleted" = false
|
||||
AND sl."hasApprovedVersion" = true;
|
||||
|
||||
COMMIT;
|
||||
@@ -0,0 +1,21 @@
|
||||
-- Migrate Claude 3.5 models to Claude 4.5 models
|
||||
-- This updates all AgentNode blocks that use deprecated Claude 3.5 models to the new 4.5 models
|
||||
-- See: https://docs.anthropic.com/en/docs/about-claude/models/legacy-model-guide
|
||||
|
||||
-- Update Claude 3.5 Sonnet to Claude 4.5 Sonnet
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
'"claude-sonnet-4-5-20250929"'::jsonb
|
||||
)
|
||||
WHERE "constantInput"::jsonb->>'model' = 'claude-3-5-sonnet-latest';
|
||||
|
||||
-- Update Claude 3.5 Haiku to Claude 4.5 Haiku
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
'"claude-haiku-4-5-20251001"'::jsonb
|
||||
)
|
||||
WHERE "constantInput"::jsonb->>'model' = 'claude-3-5-haiku-latest';
|
||||
@@ -5,10 +5,11 @@ datasource db {
|
||||
}
|
||||
|
||||
generator client {
|
||||
provider = "prisma-client-py"
|
||||
recursive_type_depth = -1
|
||||
interface = "asyncio"
|
||||
previewFeatures = ["views"]
|
||||
provider = "prisma-client-py"
|
||||
recursive_type_depth = -1
|
||||
interface = "asyncio"
|
||||
previewFeatures = ["views", "fullTextSearch"]
|
||||
partial_type_generator = "backend/data/partial_types.py"
|
||||
}
|
||||
|
||||
// User model to mirror Auth provider users
|
||||
@@ -45,6 +46,7 @@ model User {
|
||||
AnalyticsDetails AnalyticsDetails[]
|
||||
AnalyticsMetrics AnalyticsMetrics[]
|
||||
CreditTransactions CreditTransaction[]
|
||||
UserBalance UserBalance?
|
||||
|
||||
AgentPresets AgentPreset[]
|
||||
LibraryAgents LibraryAgent[]
|
||||
@@ -663,6 +665,7 @@ view StoreAgent {
|
||||
sub_heading String
|
||||
description String
|
||||
categories String[]
|
||||
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
|
||||
runs Int
|
||||
rating Float
|
||||
versions String[]
|
||||
@@ -746,7 +749,7 @@ model StoreListing {
|
||||
slug String
|
||||
|
||||
// Allow this agent to be used during onboarding
|
||||
useForOnboarding Boolean @default(false)
|
||||
useForOnboarding Boolean @default(false)
|
||||
|
||||
// The currently active version that should be shown to users
|
||||
activeVersionId String? @unique
|
||||
@@ -797,6 +800,8 @@ model StoreListingVersion {
|
||||
// Old versions can be made unavailable by the author if desired
|
||||
isAvailable Boolean @default(true)
|
||||
|
||||
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
|
||||
|
||||
// Version workflow state
|
||||
submissionStatus SubmissionStatus @default(DRAFT)
|
||||
submittedAt DateTime?
|
||||
@@ -887,6 +892,16 @@ model APIKey {
|
||||
@@index([userId, status])
|
||||
}
|
||||
|
||||
model UserBalance {
|
||||
userId String @id
|
||||
balance Int @default(0)
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
@@index([userId])
|
||||
}
|
||||
|
||||
enum APIKeyStatus {
|
||||
ACTIVE
|
||||
REVOKED
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"email": "test@example.com",
|
||||
"id": "test-user-id",
|
||||
"id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||
"name": "Test User"
|
||||
}
|
||||
@@ -28,6 +28,6 @@
|
||||
"recommended_schedule_cron": null,
|
||||
"sub_graphs": [],
|
||||
"trigger_setup_info": null,
|
||||
"user_id": "test-user-id",
|
||||
"user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||
"version": 1
|
||||
}
|
||||
@@ -26,7 +26,7 @@
|
||||
"recommended_schedule_cron": null,
|
||||
"sub_graphs": [],
|
||||
"trigger_setup_info": null,
|
||||
"user_id": "test-user-id",
|
||||
"user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||
"version": 1
|
||||
}
|
||||
]
|
||||
autogpt_platform/backend/test/blocks/test_youtube.py (new file, 140 lines)
@@ -0,0 +1,140 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from youtube_transcript_api._errors import NoTranscriptFound
|
||||
from youtube_transcript_api._transcripts import FetchedTranscript, Transcript
|
||||
|
||||
from backend.blocks.youtube import TranscribeYoutubeVideoBlock
|
||||
|
||||
|
||||
class TestTranscribeYoutubeVideoBlock:
|
||||
"""Test cases for TranscribeYoutubeVideoBlock language fallback functionality."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.youtube_block = TranscribeYoutubeVideoBlock()
|
||||
|
||||
def test_extract_video_id_standard_url(self):
|
||||
"""Test extracting video ID from standard YouTube URL."""
|
||||
url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
|
||||
video_id = self.youtube_block.extract_video_id(url)
|
||||
assert video_id == "dQw4w9WgXcQ"
|
||||
|
||||
def test_extract_video_id_short_url(self):
|
||||
"""Test extracting video ID from shortened youtu.be URL."""
|
||||
url = "https://youtu.be/dQw4w9WgXcQ"
|
||||
video_id = self.youtube_block.extract_video_id(url)
|
||||
assert video_id == "dQw4w9WgXcQ"
|
||||
|
||||
def test_extract_video_id_embed_url(self):
|
||||
"""Test extracting video ID from embed URL."""
|
||||
url = "https://www.youtube.com/embed/dQw4w9WgXcQ"
|
||||
video_id = self.youtube_block.extract_video_id(url)
|
||||
assert video_id == "dQw4w9WgXcQ"
|
||||
|
||||
@patch("backend.blocks.youtube.YouTubeTranscriptApi")
|
||||
def test_get_transcript_english_available(self, mock_api_class):
|
||||
"""Test getting transcript when English is available."""
|
||||
# Setup mock
|
||||
mock_api = Mock()
|
||||
mock_api_class.return_value = mock_api
|
||||
mock_transcript = Mock(spec=FetchedTranscript)
|
||||
mock_api.fetch.return_value = mock_transcript
|
||||
|
||||
# Execute
|
||||
result = TranscribeYoutubeVideoBlock.get_transcript("test_video_id")
|
||||
|
||||
# Assert
|
||||
assert result == mock_transcript
|
||||
mock_api.fetch.assert_called_once_with(video_id="test_video_id")
|
||||
mock_api.list.assert_not_called()
|
||||
|
||||
@patch("backend.blocks.youtube.YouTubeTranscriptApi")
|
||||
def test_get_transcript_fallback_to_first_available(self, mock_api_class):
|
||||
"""Test fallback to first available language when English is not available."""
|
||||
# Setup mock
|
||||
mock_api = Mock()
|
||||
mock_api_class.return_value = mock_api
|
||||
|
||||
# Create mock transcript list with Hungarian transcript
|
||||
mock_transcript_list = Mock()
|
||||
mock_transcript_hu = Mock(spec=Transcript)
|
||||
mock_fetched_transcript = Mock(spec=FetchedTranscript)
|
||||
mock_transcript_hu.fetch.return_value = mock_fetched_transcript
|
||||
|
||||
# Set up the transcript list with no manually created transcripts
# and a single auto-generated Hungarian transcript
|
||||
mock_transcript_list._manually_created_transcripts = {}
|
||||
mock_transcript_list._generated_transcripts = {"hu": mock_transcript_hu}
|
||||
|
||||
# Mock API to raise NoTranscriptFound for English, then return list
|
||||
mock_api.fetch.side_effect = NoTranscriptFound(
|
||||
"test_video_id", ("en",), mock_transcript_list
|
||||
)
|
||||
mock_api.list.return_value = mock_transcript_list
|
||||
|
||||
# Execute
|
||||
result = TranscribeYoutubeVideoBlock.get_transcript("test_video_id")
|
||||
|
||||
# Assert
|
||||
assert result == mock_fetched_transcript
|
||||
mock_api.fetch.assert_called_once_with(video_id="test_video_id")
|
||||
mock_api.list.assert_called_once_with("test_video_id")
|
||||
mock_transcript_hu.fetch.assert_called_once()
|
||||
|
||||
@patch("backend.blocks.youtube.YouTubeTranscriptApi")
|
||||
def test_get_transcript_prefers_manually_created(self, mock_api_class):
|
||||
"""Test that manually created transcripts are preferred over generated ones."""
|
||||
# Setup mock
|
||||
mock_api = Mock()
|
||||
mock_api_class.return_value = mock_api
|
||||
|
||||
# Create mock transcript list with both manual and generated transcripts
|
||||
mock_transcript_list = Mock()
|
||||
mock_transcript_manual = Mock(spec=Transcript)
|
||||
mock_transcript_generated = Mock(spec=Transcript)
|
||||
mock_fetched_manual = Mock(spec=FetchedTranscript)
|
||||
mock_transcript_manual.fetch.return_value = mock_fetched_manual
|
||||
|
||||
# Set up the transcript list
|
||||
mock_transcript_list._manually_created_transcripts = {
|
||||
"es": mock_transcript_manual
|
||||
}
|
||||
mock_transcript_list._generated_transcripts = {"hu": mock_transcript_generated}
|
||||
|
||||
# Mock API to raise NoTranscriptFound for English
|
||||
mock_api.fetch.side_effect = NoTranscriptFound(
|
||||
"test_video_id", ("en",), mock_transcript_list
|
||||
)
|
||||
mock_api.list.return_value = mock_transcript_list
|
||||
|
||||
# Execute
|
||||
result = TranscribeYoutubeVideoBlock.get_transcript("test_video_id")
|
||||
|
||||
# Assert - should use manually created transcript first
|
||||
assert result == mock_fetched_manual
|
||||
mock_transcript_manual.fetch.assert_called_once()
|
||||
mock_transcript_generated.fetch.assert_not_called()
|
||||
|
||||
@patch("backend.blocks.youtube.YouTubeTranscriptApi")
|
||||
def test_get_transcript_no_transcripts_available(self, mock_api_class):
|
||||
"""Test that exception is re-raised when no transcripts are available at all."""
|
||||
# Setup mock
|
||||
mock_api = Mock()
|
||||
mock_api_class.return_value = mock_api
|
||||
|
||||
# Create mock transcript list with no transcripts
|
||||
mock_transcript_list = Mock()
|
||||
mock_transcript_list._manually_created_transcripts = {}
|
||||
mock_transcript_list._generated_transcripts = {}
|
||||
|
||||
# Mock API to raise NoTranscriptFound
|
||||
original_exception = NoTranscriptFound(
|
||||
"test_video_id", ("en",), mock_transcript_list
|
||||
)
|
||||
mock_api.fetch.side_effect = original_exception
|
||||
mock_api.list.return_value = mock_transcript_list
|
||||
|
||||
# Execute and assert exception is raised
|
||||
with pytest.raises(NoTranscriptFound):
|
||||
TranscribeYoutubeVideoBlock.get_transcript("test_video_id")
|
||||
@@ -749,10 +749,11 @@ class TestDataCreator:
|
||||
"""Add credits to users."""
|
||||
print("Adding credits to users...")
|
||||
|
||||
credit_model = get_user_credit_model()
|
||||
|
||||
for user in self.users:
|
||||
try:
|
||||
# Get user-specific credit model
|
||||
credit_model = await get_user_credit_model(user["id"])
|
||||
|
||||
# Skip credits for disabled credit model to avoid errors
|
||||
if (
|
||||
hasattr(credit_model, "__class__")
|
||||
|
||||
@@ -21,6 +21,7 @@ import random
|
||||
from datetime import datetime
|
||||
|
||||
import prisma.enums
|
||||
import pytest
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from faker import Faker
|
||||
from prisma import Json, Prisma
|
||||
@@ -498,9 +499,6 @@ async def main():
|
||||
if store_listing_versions and random.random() < 0.5
|
||||
else None
|
||||
),
|
||||
"agentInput": (
|
||||
Json({"test": "data"}) if random.random() < 0.3 else None
|
||||
),
|
||||
"onboardingAgentExecutionId": (
|
||||
random.choice(agent_graph_executions).id
|
||||
if agent_graph_executions and random.random() < 0.3
|
||||
@@ -570,5 +568,11 @@ async def main():
|
||||
print("Test data creation completed successfully!")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
async def test_main_function_runs_without_errors():
|
||||
await main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
NEXT_PUBLIC_LAUNCHDARKLY_ENABLED=false
|
||||
NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID=687ab1372f497809b131e06e
|
||||
|
||||
NEXT_PUBLIC_SHOW_BILLING_PAGE=false
|
||||
NEXT_PUBLIC_TURNSTILE=disabled
|
||||
NEXT_PUBLIC_REACT_QUERY_DEVTOOL=true
|
||||
|
||||
|
||||
autogpt_platform/frontend/CONTRIBUTING.md (new file, 765 lines)
@@ -0,0 +1,765 @@
|
||||
<div align="center">
|
||||
<h1>AutoGPT Frontend • Contributing ⌨️</h1>
|
||||
<p>Next.js App Router • Client-first • Type-safe generated API hooks • Tailwind + shadcn/ui</p>
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
## ☕️ Summary
|
||||
|
||||
This document is your reference for contributing to the AutoGPT Frontend. It adapts legacy guidelines to our current stack and practices.
|
||||
|
||||
- Architecture and stack
|
||||
- Component structure and design system
|
||||
- Data fetching (generated API hooks)
|
||||
- Feature flags
|
||||
- Naming and code conventions
|
||||
- Tooling, scripts, and testing
|
||||
- PR process and checklist
|
||||
|
||||
This is a living document. Open a pull request any time to improve it.
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Quick Start FAQ
|
||||
|
||||
New to the codebase? Here are shortcuts to common tasks:
|
||||
|
||||
### I need to make a new page
|
||||
|
||||
1. Create page in `src/app/(platform)/your-feature/page.tsx`
|
||||
2. If it has logic, create `usePage.ts` hook next to it
|
||||
3. Create sub-components in `components/` folder
|
||||
4. Use generated API hooks for data fetching
|
||||
5. If page needs auth, ensure it's in the `(platform)` route group
|
||||
|
||||
**Example structure:**
|
||||
|
||||
```
|
||||
app/(platform)/dashboard/
|
||||
page.tsx
|
||||
useDashboardPage.ts
|
||||
components/
|
||||
StatsPanel/
|
||||
StatsPanel.tsx
|
||||
useStatsPanel.ts
|
||||
```
|
||||
|
||||
See the [Component structure](#-component-structure), [Styling](#-styling), and [Data fetching patterns](#-data-fetching-patterns) sections.
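As a rough sketch of steps 1–4 above (the `dashboard` route, the hook file, and the `id`/`name` fields on agents are illustrative; the generated hook is the same one used in the API example further down):

```tsx
// app/(platform)/dashboard/useDashboardPage.ts (data/behavior only, illustrative)
import { useGetV2ListLibraryAgents } from "@/app/api/__generated__/endpoints/library/library";

export function useDashboardPage() {
  const { data, isLoading, isError } = useGetV2ListLibraryAgents();
  return { agents: data?.data || [], isLoading, isError };
}
```

```tsx
// app/(platform)/dashboard/page.tsx (render logic only, illustrative)
"use client";

import { useDashboardPage } from "./useDashboardPage";

export default function DashboardPage() {
  const { agents, isLoading, isError } = useDashboardPage();

  if (isLoading) return <p>Loading…</p>;
  if (isError) return <p>Something went wrong.</p>;

  return (
    <ul>
      {agents.map((agent) => (
        <li key={agent.id}>{agent.name}</li>
      ))}
    </ul>
  );
}
```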
|
||||
|
||||
### I need to update an existing component in a page
|
||||
|
||||
1. Find the page `src/app/(platform)/your-feature/page.tsx`
|
||||
2. Check its `components/` folder
|
||||
3. If you need to update its logic, check the `use[Component].ts` hook
4. If the update is related to rendering, check the `[Component].tsx` file
|
||||
|
||||
See [Component structure](#-component-structure) and [Styling](#-styling) sections.
|
||||
|
||||
### I need to make a new API call and show it on the UI
|
||||
|
||||
1. Ensure the backend endpoint exists in the OpenAPI spec
|
||||
2. Regenerate API client: `pnpm generate:api`
|
||||
3. Import the generated hook by typing the operation name (auto-import)
|
||||
4. Use the hook in your component/custom hook
|
||||
5. Handle loading, error, and success states
|
||||
|
||||
**Example:**
|
||||
|
||||
```tsx
|
||||
import { useGetV2ListLibraryAgents } from "@/app/api/__generated__/endpoints/library/library";
|
||||
|
||||
export function useAgentList() {
|
||||
const { data, isLoading, isError, error } = useGetV2ListLibraryAgents();
|
||||
|
||||
return {
|
||||
agents: data?.data || [],
|
||||
isLoading,
|
||||
isError,
|
||||
error,
|
||||
};
|
||||
}
|
||||
```
|
||||
|
||||
See [Data fetching patterns](#-data-fetching-patterns) for more examples.
|
||||
|
||||
### I need to create a new component in the Design System
|
||||
|
||||
1. Determine the atomic level: atom, molecule, or organism
|
||||
2. Create folder: `src/components/[level]/ComponentName/`
|
||||
3. Create `ComponentName.tsx` (render logic)
|
||||
4. If logic exists, create `useComponentName.ts`
|
||||
5. Create `ComponentName.stories.tsx` for Storybook
|
||||
6. Use Tailwind + design tokens (avoid hardcoded values)
|
||||
7. Only use Phosphor icons
|
||||
8. Test in Storybook: `pnpm storybook`
|
||||
9. Verify in Chromatic after PR
|
||||
|
||||
**Example structure:**
|
||||
|
||||
```
|
||||
src/components/molecules/DataCard/
|
||||
DataCard.tsx
|
||||
DataCard.stories.tsx
|
||||
useDataCard.ts
|
||||
```
|
||||
|
||||
See [Component structure](#-component-structure) and [Styling](#-styling) sections.
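A minimal sketch of such a molecule and its story (the `DataCard` name, props, and Tailwind classes are illustrative, not an existing component):

```tsx
// src/components/molecules/DataCard/DataCard.tsx (render logic only)
interface DataCardProps {
  title: string;
  value: string;
}

export function DataCard({ title, value }: DataCardProps) {
  return (
    <div className="rounded-lg border p-4">
      <p className="text-sm text-neutral-500">{title}</p>
      <p className="text-lg font-semibold">{value}</p>
    </div>
  );
}
```

```tsx
// src/components/molecules/DataCard/DataCard.stories.tsx
import type { Meta, StoryObj } from "@storybook/react";

import { DataCard } from "./DataCard";

const meta = {
  title: "Molecules/DataCard",
  component: DataCard,
} satisfies Meta<typeof DataCard>;

export default meta;

export const Default: StoryObj<typeof meta> = {
  args: { title: "Total runs", value: "1,024" },
};
```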
|
||||
|
||||
---
|
||||
|
||||
## 📟 Contribution process
|
||||
|
||||
### 1) Branch off `dev`
|
||||
|
||||
- Branch from `dev` for features and fixes
|
||||
- Keep PRs focused (aim for one ticket per PR)
|
||||
- Use conventional commit messages with a scope (e.g., `feat(frontend): add X`)
|
||||
|
||||
### 2) Feature flags
|
||||
|
||||
If a feature will ship across multiple PRs, guard it with a flag so we can merge iteratively.
|
||||
|
||||
- Use [LaunchDarkly](https://www.launchdarkly.com) based flags (see Feature Flags below)
|
||||
- Avoid long-lived feature branches
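For example, an unfinished entry point can be hidden behind a flag roughly like this (a sketch assuming the standard LaunchDarkly React SDK `useFlags` hook and a hypothetical `newAgentBuilder` flag key; check the Feature Flags section below for the project's actual wrapper and keys):

```tsx
import { useFlags } from "launchdarkly-react-client-sdk";

export function NewBuilderEntry() {
  // "newAgentBuilder" is a hypothetical, camelCased flag key
  const { newAgentBuilder } = useFlags();

  // Keep the unfinished feature hidden until the flag is turned on
  if (!newAgentBuilder) return null;

  return <a href="/build-v2">Try the new builder</a>;
}
```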
|
||||
|
||||
### 3) Open PR and get reviews ✅
|
||||
|
||||
Before requesting review:
|
||||
|
||||
- [x] Code follows architecture and conventions here
|
||||
- [x] `pnpm format && pnpm lint && pnpm types` pass
|
||||
- [x] Relevant tests pass locally: `pnpm test` (and/or Storybook tests)
|
||||
- [x] If touching UI, validate against our design system and stories
|
||||
|
||||
### 4) Merge to `dev`
|
||||
|
||||
- Use squash merges
|
||||
- Follow conventional commit message format for the squash title
|
||||
|
||||
---
|
||||
|
||||
## 📂 Architecture & Stack
|
||||
|
||||
### Next.js App Router
|
||||
|
||||
- We use the [Next.js App Router](https://nextjs.org/docs/app) in `src/app`
|
||||
- Use [route segments](https://nextjs.org/docs/app/building-your-application/routing) with semantic URLs; no `pages/`
|
||||
|
||||
### Component good practices
|
||||
|
||||
- Default to client components
|
||||
- Use server components only when:
|
||||
- SEO requires server-rendered HTML, or
|
||||
- Extreme first-byte performance justifies it
|
||||
- If you render server-side data, prefer server-side prefetch + client hydration (see examples below and [React Query SSR & Hydration](https://tanstack.com/query/latest/docs/framework/react/guides/ssr))
|
||||
- Prefer using [Next.js API routes](https://nextjs.org/docs/pages/building-your-application/routing/api-routes) when possible over [server actions](https://nextjs.org/docs/14/app/building-your-application/data-fetching/server-actions-and-mutations)
|
||||
- Keep components small and simple
|
||||
- favour composition and splitting large components into smaller bits of UI
|
||||
- [colocate state](https://kentcdodds.com/blog/state-colocation-will-make-your-react-app-faster) when possible
|
||||
- keep render/side-effects split for [separation of concerns](https://en.wikipedia.org/wiki/Separation_of_concerns)
|
||||
- do not over-complicate or re-invent the wheel
|
||||
|
||||
**❓ Why a client-side first design vs server components/actions?**
|
||||
|
||||
While server components and actions are cool and cutting-edge, they introduce a layer of complexity which not always justified by the benefits they deliver. Defaulting to client-first keeps things simple in the mental model of the developer, specially for those developers less familiar with Next.js or heavy Front-end development.
|
||||
|
||||
### Data fetching: prefer generated API hooks
|
||||
|
||||
- We generate a type-safe client and React Query hooks from the backend OpenAPI spec via [Orval](https://orval.dev/)
|
||||
- Prefer the generated hooks under `src/app/api/__generated__/endpoints/...`
|
||||
- Treat `BackendAPI` and code under `src/lib/autogpt-server-api/*` as deprecated; do not introduce new usages
|
||||
- Use [Zod](https://zod.dev/) schemas from the generated client where applicable
|
||||
|
||||
### State management
|
||||
|
||||
- Prefer [React Query](https://tanstack.com/query/latest/docs/framework/react/overview) for server state, colocated near consumers (see [state colocation](https://kentcdodds.com/blog/state-colocation-will-make-your-react-app-faster))
|
||||
- Co-locate UI state inside components/hooks; keep global state minimal
|
||||
|
||||
### Styling and components
|
||||
|
||||
- [Tailwind CSS](https://tailwindcss.com/docs) + [shadcn/ui](https://ui.shadcn.com/) ([Radix Primitives](https://www.radix-ui.com/docs/primitives/overview/introduction) under the hood)
|
||||
- Use the design system under `src/components` for primitives and building blocks
|
||||
- Do not use anything under `src/components/_legacy__`; migrate away from it when touching old code
|
||||
- Reference the design system catalog on Chromatic: [`https://dev--670f94474adee5e32c896b98.chromatic.com/`](https://dev--670f94474adee5e32c896b98.chromatic.com/)
|
||||
- Use the [`tailwind-scrollbar`](https://www.npmjs.com/package/tailwind-scrollbar) plugin utilities for scrollbar styling
|
||||
|
||||
---
|
||||
|
||||
## 🧱 Component structure

For components, separate render logic from data/behavior, and keep implementation details local.

**Most components should follow this structure.** Pages are just bigger components made of smaller ones, and sub-components can have their own nested sub-components when dealing with complex features.

### Basic structure

When a component has non-trivial logic:

```
FeatureX/
  FeatureX.tsx (render logic only)
  useFeatureX.ts (hook; data fetching, behavior, state)
  helpers.ts (pure helpers used by the hook)
  components/ (optional, subcomponents local to FeatureX)
```

### Example: Page with nested components

```
// Page composition
app/(platform)/dashboard/
  page.tsx
  useDashboardPage.ts
  components/ # (Sub-components the dashboard page is made of)
    StatsPanel/
      StatsPanel.tsx
      useStatsPanel.ts
      helpers.ts
      components/ # (Sub-components belonging to StatsPanel)
        StatCard/
          StatCard.tsx
    ActivityFeed/
      ActivityFeed.tsx
      useActivityFeed.ts
```

### Guidelines

- Prefer function declarations for components and handlers (see the sketch below)
- Only use arrow functions for small inline lambdas (e.g., in `map`)
- Avoid barrel files and `index.ts` re-exports
- Keep component files focused and readable; push complex logic to `helpers.ts`
- Abstract reusable, cross-feature logic into `src/services/` or `src/lib/utils.ts` as appropriate
- Build components encapsulated so they can be easily reused and abstracted elsewhere
- Nest sub-components within a `components/` folder when they're local to the parent feature
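For example, a minimal sketch of this split (the component name and props are illustrative):

```tsx
// Component and handler as function declarations
export function AgentRow({ name, onOpen }: { name: string; onOpen: () => void }) {
  function handleClick() {
    onOpen();
  }

  return <button onClick={handleClick}>{name}</button>;
}

// Arrow functions only for small inline lambdas
const labels = ["a", "b"].map((item) => item.toUpperCase());
```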
### Exceptions

When to simplify the structure:

**Small hook logic (3-4 lines)**

If the hook logic is minimal, keep it inline with the render function:

```tsx
export function ActivityAlert() {
  const [isVisible, setIsVisible] = useState(true);
  if (!isVisible) return null;

  return (
    <Alert onClose={() => setIsVisible(false)}>New activity detected</Alert>
  );
}
```

**Render-only components**

Components with no hook logic can be direct files in `components/` without a folder:

```
components/
  ActivityAlert.tsx (render-only, no folder needed)
  StatsPanel/ (has hook logic, needs folder)
    StatsPanel.tsx
    useStatsPanel.ts
```

### Hook file structure

When separating logic into a custom hook:

```tsx
// useStatsPanel.ts
export function useStatsPanel() {
  const [data, setData] = useState<Stats[]>([]);
  const [isLoading, setIsLoading] = useState(true);

  useEffect(() => {
    fetchStats().then((stats) => {
      setData(stats);
      setIsLoading(false);
    });
  }, []);

  return {
    data,
    isLoading,
    refresh: () => fetchStats().then(setData),
  };
}
```

Rules:

- **Always return an object** that exposes data and methods to the view
- **Export a single function** named after the component (e.g., `useStatsPanel` for `StatsPanel.tsx`)
- **Abstract into `helpers.ts`** when hook logic grows large, so the hook file remains readable by scanning without diving into implementation details

---
## 🔄 Data fetching patterns

All API hooks are generated from the backend OpenAPI specification using [Orval](https://orval.dev/). The hooks are type-safe and follow the operation names defined in the backend API.

### How to discover hooks

Most of the time you can rely on auto-import by typing the endpoint or operation name. Your IDE will suggest the generated hooks based on the OpenAPI operation IDs.

**Examples of hook naming patterns:**

- `GET /api/v1/notifications` → `useGetV1GetNotificationPreferences`
- `POST /api/v2/store/agents` → `usePostV2CreateStoreAgent`
- `DELETE /api/v2/store/submissions/{id}` → `useDeleteV2DeleteStoreSubmission`
- `GET /api/v2/library/agents` → `useGetV2ListLibraryAgents`

**Pattern**: `use{Method}{Version}{OperationName}`

You can also explore the generated hooks by browsing `src/app/api/__generated__/endpoints/` which is organized by API tags (e.g., `auth`, `store`, `library`).

**OpenAPI specs:**

- Production: [https://backend.agpt.co/openapi.json](https://backend.agpt.co/openapi.json)
- Staging: [https://dev-server.agpt.co/openapi.json](https://dev-server.agpt.co/openapi.json)
### Generated hooks (client)

Prefer the generated React Query hooks (via Orval + React Query):

```tsx
import { useGetV1GetNotificationPreferences } from "@/app/api/__generated__/endpoints/auth/auth";

export function PreferencesPanel() {
  const { data, isLoading, isError } = useGetV1GetNotificationPreferences({
    query: {
      select: (res) => res.data,
    },
  });

  if (isLoading) return null;
  if (isError) throw new Error("Failed to load preferences");
  return <pre>{JSON.stringify(data, null, 2)}</pre>;
}
```

### Generated mutations (client)

```tsx
import { useQueryClient } from "@tanstack/react-query";
import {
  useDeleteV2DeleteStoreSubmission,
  getGetV2ListMySubmissionsQueryKey,
} from "@/app/api/__generated__/endpoints/store/store";

export function DeleteSubmissionButton({
  submissionId,
}: {
  submissionId: string;
}) {
  const queryClient = useQueryClient();
  const { mutateAsync: deleteSubmission, isPending } =
    useDeleteV2DeleteStoreSubmission({
      mutation: {
        onSuccess: () => {
          queryClient.invalidateQueries({
            queryKey: getGetV2ListMySubmissionsQueryKey(),
          });
        },
      },
    });

  async function onClick() {
    await deleteSubmission({ submissionId });
  }

  return (
    <button disabled={isPending} onClick={onClick}>
      Delete
    </button>
  );
}
```
### Server-side prefetch + client hydration

Use server-side prefetch to improve TTFB while keeping the component tree client-first (see [React Query SSR & Hydration](https://tanstack.com/query/latest/docs/framework/react/guides/ssr)):

```tsx
// in a server component
import { getQueryClient } from "@/lib/tanstack-query/getQueryClient";
import { HydrationBoundary, dehydrate } from "@tanstack/react-query";
import {
  prefetchGetV2ListStoreAgentsQuery,
  prefetchGetV2ListStoreCreatorsQuery,
} from "@/app/api/__generated__/endpoints/store/store";

export default async function MarketplacePage() {
  const queryClient = getQueryClient();

  await Promise.all([
    prefetchGetV2ListStoreAgentsQuery(queryClient, { featured: true }),
    prefetchGetV2ListStoreAgentsQuery(queryClient, { sorted_by: "runs" }),
    prefetchGetV2ListStoreCreatorsQuery(queryClient, {
      featured: true,
      sorted_by: "num_agents",
    }),
  ]);

  return (
    <HydrationBoundary state={dehydrate(queryClient)}>
      {/* Client component tree goes here */}
    </HydrationBoundary>
  );
}
```

Notes:

- Do not introduce new usages of `BackendAPI` or `src/lib/autogpt-server-api/*`
- Keep transformations and mapping logic close to the consumer (hook), not in the view

---
## ⚠️ Error handling

The app has multiple error handling strategies depending on the type of error:

### Render/runtime errors

Use `<ErrorCard />` to display render or runtime errors gracefully:

```tsx
import { ErrorCard } from "@/components/molecules/ErrorCard";

export function DataPanel() {
  const { data, isLoading, isError, error } = useGetData();

  if (isLoading) return <Skeleton />;
  if (isError) return <ErrorCard error={error} />;

  return <div>{data.content}</div>;
}
```

### API mutation errors

Display mutation errors using toast notifications:

```tsx
import { useToast } from "@/components/ui/use-toast";

export function useUpdateSettings() {
  const { toast } = useToast();
  const { mutateAsync: updateSettings } = useUpdateSettingsMutation({
    mutation: {
      onError: (error) => {
        toast({
          title: "Failed to update settings",
          description: error.message,
          variant: "destructive",
        });
      },
    },
  });

  return { updateSettings };
}
```
### Manual Sentry capture

When needed, you can manually capture exceptions to Sentry:

```tsx
import * as Sentry from "@sentry/nextjs";

try {
  await riskyOperation();
} catch (error) {
  Sentry.captureException(error, {
    tags: { context: "feature-x" },
    extra: { metadata: additionalData },
  });
  throw error;
}
```

### Global error boundaries

The app has error boundaries already configured to:

- Capture uncaught errors globally and send them to Sentry
- Display a user-friendly error UI when something breaks
- Prevent the entire app from crashing

You don't need to wrap components in error boundaries manually unless you need custom error recovery logic.

---
## 🚩 Feature Flags

- Flags are powered by [LaunchDarkly](https://docs.launchdarkly.com/)
- Use the helper APIs under `src/services/feature-flags`

Check a flag in a client component:

```tsx
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";

export function AgentActivityPanel() {
  const enabled = useGetFlag(Flag.AGENT_ACTIVITY);
  if (!enabled) return null;
  return <div>Feature is enabled!</div>;
}
```

Protect a route or page component:

```tsx
import { withFeatureFlag } from "@/services/feature-flags/with-feature-flag";

export const MyFeaturePage = withFeatureFlag(function Page() {
  return <div>My feature page</div>;
}, "my-feature-flag");
```

Local dev and Playwright:

- Set `NEXT_PUBLIC_PW_TEST=true` to use mocked flag values during local development and tests

Adding new flags:

1. Add the flag to the `Flag` enum and `FlagValues` type
2. Provide a mock value in the mock map
3. Configure the flag in LaunchDarkly (see the sketch after this list)
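A minimal sketch of steps 1 and 2, assuming the flag files keep their current shape (`use-get-flag.ts` with a `Flag` enum, a `FlagValues` type, and a `mockFlags` map); the new flag name is hypothetical:

```ts
// src/services/feature-flags/use-get-flag.ts (illustrative)
export enum Flag {
  AGENT_ACTIVITY = "agent-activity",
  MY_NEW_FEATURE = "my-new-feature", // hypothetical new flag
}

export type FlagValues = {
  [Flag.AGENT_ACTIVITY]: boolean;
  [Flag.MY_NEW_FEATURE]: boolean; // hypothetical new flag
};

// Mock values used when NEXT_PUBLIC_PW_TEST=true
const mockFlags: FlagValues = {
  [Flag.AGENT_ACTIVITY]: true,
  [Flag.MY_NEW_FEATURE]: false, // hypothetical new flag
};
```

Step 3 then happens in the LaunchDarkly dashboard.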
---
## 📙 Naming conventions

General:

- Variables and functions should read like plain English
- Prefer `const` over `let` unless reassignment is required
- Use searchable constants instead of magic numbers

Files:

- Components and hooks: `PascalCase` for component files, `camelCase` for hooks
- Other files: `kebab-case`
- Do not create barrel files or `index.ts` re-exports

Types:

- Prefer `interface` for object shapes
- Component props should be `interface Props { ... }`
- Use precise types; avoid `any` and unsafe casts

Parameters:

- If more than one parameter is needed, pass a single `Args` object for clarity (see the sketch below)

Comments:

- Keep comments minimal; code should be clear by itself
- Only document non-obvious intent, invariants, or caveats

Functions:

- Prefer function declarations for components and handlers
- Only use arrow functions for small inline callbacks

Control flow:

- Use early returns to reduce nesting
- Avoid catching errors unless you handle them meaningfully
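A small sketch tying several of these conventions together (an `Args` object, `const`, and early returns); the names are illustrative:

```ts
interface FormatRunLabelArgs {
  agentName: string;
  runCount: number;
}

function formatRunLabel(args: FormatRunLabelArgs): string {
  const { agentName, runCount } = args;

  // Early return keeps the happy path unindented
  if (runCount === 0) return `${agentName} has not run yet`;

  return `${agentName} ran ${runCount} time${runCount === 1 ? "" : "s"}`;
}
```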
---
## 🎨 Styling

- Use Tailwind utilities; prefer semantic, composable class names
- Use shadcn/ui components as building blocks when available
- Use the `tailwind-scrollbar` utilities for scrollbar styling
- Keep responsive and dark-mode behavior consistent with the design system

Additional requirements:

- Do not import shadcn primitives directly in feature code; only use components exposed in our design system under `src/components`. shadcn is a low-level skeleton we style on top of and is not meant to be consumed directly.
- Prefer design tokens over Tailwind's default theme whenever possible (e.g., color, spacing, radius, and typography tokens). Avoid hardcoded values and default palette if a token exists.
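For instance, a hedged sketch of the token-first intent, assuming shadcn-style semantic token classes are configured in the Tailwind theme (the exact token names in this codebase may differ):

```tsx
export function SyncStatus() {
  return (
    // Prefer token-backed classes (bg-muted, text-muted-foreground) over
    // hardcoded values such as bg-[#f4f4f5] or text-[#52525b]
    <p className="rounded-md bg-muted p-4 text-muted-foreground">Synced</p>
  );
}
```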
---
## ⚠️ Errors and ⏳ Loading

- **Errors**: Use the `ErrorCard` component from the design system to display API/HTTP errors and retry actions. Keep error derivation/mapping in hooks; pass the final message to the component.
  - Component: `src/components/molecules/ErrorCard/ErrorCard.tsx`
- **Loading**: Use the `Skeleton` component(s) from the design system for loading states. Favor domain-appropriate skeleton layouts (lists, cards, tables) over spinners.
  - See Storybook examples under Atoms/Skeleton for patterns.
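A hedged sketch of a list-shaped loading state; the `Skeleton` import path and props are assumptions based on the Atoms/Skeleton story location:

```tsx
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton"; // assumed path

export function AgentListSkeleton() {
  return (
    <div className="flex flex-col gap-3">
      {/* Mirror the shape of the real list instead of showing a spinner */}
      {[0, 1, 2].map((row) => (
        <Skeleton key={row} className="h-12 w-full rounded-md" />
      ))}
    </div>
  );
}
```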
---
## 🧭 Responsive and mobile-first

- Build mobile-first. Ensure new UI looks great from a 375px viewport width (iPhone SE) upwards.
- Validate layouts at common breakpoints (375, 768, 1024, 1280). Prefer stacking and progressive disclosure on small screens.
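A minimal sketch of the mobile-first pattern with plain Tailwind utilities: base classes target the smallest screens, and `md:`/`lg:` prefixes layer on behavior for larger breakpoints (the component is illustrative):

```tsx
export function StatsRow() {
  return (
    // Stacked on mobile, side by side from the md breakpoint up
    <section className="flex flex-col gap-4 md:flex-row lg:gap-8">
      <div className="w-full md:w-1/2">Runs</div>
      <div className="w-full md:w-1/2">Credits</div>
    </section>
  );
}
```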
---
## 🧰 State for complex flows

For components/flows with complex state, multi-step wizards, or cross-component coordination, prefer a small co-located store using [Zustand](https://github.com/pmndrs/zustand).

Guidelines:

- Co-locate the store with the feature (e.g., `FeatureX/store.ts`).
- Expose typed selectors to minimize re-renders.
- Keep effects and API calls in hooks; stores hold state and pure actions.

Example: simple store with selectors

```tsx
import { create } from "zustand";

interface WizardState {
  step: number;
  data: Record<string, unknown>;
  next(): void;
  back(): void;
  setField(args: { key: string; value: unknown }): void;
}

export const useWizardStore = create<WizardState>((set) => ({
  step: 0,
  data: {},
  next() {
    set((state) => ({ step: state.step + 1 }));
  },
  back() {
    set((state) => ({ step: Math.max(0, state.step - 1) }));
  },
  setField({ key, value }) {
    set((state) => ({ data: { ...state.data, [key]: value } }));
  },
}));

// Usage in a component (selectors keep updates scoped)
function WizardFooter() {
  const step = useWizardStore((s) => s.step);
  const next = useWizardStore((s) => s.next);
  const back = useWizardStore((s) => s.back);

  return (
    <div className="flex items-center gap-2">
      <button onClick={back} disabled={step === 0}>Back</button>
      <button onClick={next}>Next</button>
    </div>
  );
}
```
Example: async action coordinated via hook + store

```ts
// FeatureX/useFeatureX.ts
import { useMutation } from "@tanstack/react-query";
import { useWizardStore } from "./store";

export function useFeatureX() {
  const setField = useWizardStore((s) => s.setField);
  const next = useWizardStore((s) => s.next);

  const { mutateAsync: save, isPending } = useMutation({
    mutationFn: async (payload: unknown) => {
      // call API here
      return payload;
    },
    onSuccess(data) {
      setField({ key: "result", value: data });
      next();
    },
  });

  return { save, isSaving: isPending };
}
```

---
## 🖼 Icons

- Only use Phosphor Icons. Treat all other icon libraries as deprecated for new code.
- Package: `@phosphor-icons/react`
- Site: [`https://phosphoricons.com/`](https://phosphoricons.com/)

Example usage:

```tsx
import { Plus } from "@phosphor-icons/react";

export function CreateButton() {
  return (
    <button type="button" className="inline-flex items-center gap-2">
      <Plus size={16} />
      Create
    </button>
  );
}
```

---
## 🧪 Testing & Storybook
|
||||
|
||||
- End-to-end: [Playwright](https://playwright.dev/docs/intro) (`pnpm test`, `pnpm test-ui`)
|
||||
- [Storybook](https://storybook.js.org/docs) for isolated UI development (`pnpm storybook` / `pnpm build-storybook`)
|
||||
- For Storybook tests in CI, see [`@storybook/test-runner`](https://storybook.js.org/docs/writing-tests/test-runner) (`test-storybook:ci`)
|
||||
- When changing components in `src/components`, update or add stories and visually verify in Storybook/Chromatic
|
||||
|
||||
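A minimal CSF3 story sketch for a design-system component; `DataCard` reuses the example folder from the Component structure section, and its props are illustrative:

```tsx
// src/components/molecules/DataCard/DataCard.stories.tsx
import type { Meta, StoryObj } from "@storybook/react";
import { DataCard } from "./DataCard";

const meta = {
  title: "Molecules/DataCard",
  component: DataCard,
} satisfies Meta<typeof DataCard>;

export default meta;
type Story = StoryObj<typeof meta>;

export const Default: Story = {
  args: {
    title: "Runs this week",
    value: "42",
  },
};
```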
---
## 🛠 Tooling & Scripts
|
||||
|
||||
Common scripts (see `package.json` for full list):
|
||||
|
||||
- `pnpm dev` — Start Next.js dev server (generates API client first)
|
||||
- `pnpm build` — Build for production
|
||||
- `pnpm start` — Start production server
|
||||
- `pnpm lint` — ESLint + Prettier check
|
||||
- `pnpm format` — Format code
|
||||
- `pnpm types` — Type-check
|
||||
- `pnpm storybook` — Run Storybook
|
||||
- `pnpm test` — Run Playwright tests
|
||||
|
||||
Generated API client:
|
||||
|
||||
- `pnpm generate:api` — Fetch OpenAPI spec and regenerate the client
|
||||
|
||||
---
|
||||
|
||||
## ✅ PR checklist (Frontend)

- Client-first: server components only for SEO or extreme TTFB needs
- Uses generated API hooks; no new `BackendAPI` usages
- UI uses `src/components` primitives; no new `_legacy__` components
- Logic is separated into `use*.ts` and `helpers.ts` when non-trivial
- Reusable logic extracted to `src/services/` or `src/lib/utils.ts` when appropriate
- Navigation uses the Next.js router
- Lint, format, type-check, and tests pass locally
- Stories updated/added if UI changed; verified in Storybook

---
## ♻️ Migration guidance

When touching legacy code:

- Replace usages of `src/components/_legacy__/*` with the modern design system components under `src/components`
- Replace `BackendAPI` or `src/lib/autogpt-server-api/*` with generated API hooks
- Move presentational logic into render files and data/behavior into hooks
- Keep one-off transformations in local `helpers.ts`; move reusable logic to `src/services/` or `src/lib/utils.ts`

---
## 📚 References

- Design system (Chromatic): [`https://dev--670f94474adee5e32c896b98.chromatic.com/`](https://dev--670f94474adee5e32c896b98.chromatic.com/)
- Project README for setup and API client examples: `autogpt_platform/frontend/README.md`
- Conventional Commits: [conventionalcommits.org](https://www.conventionalcommits.org/)
@@ -4,20 +4,12 @@ This is the frontend for AutoGPT's next generation

This project uses [**pnpm**](https://pnpm.io/) as the package manager via **corepack**. [Corepack](https://github.com/nodejs/corepack) is a Node.js tool that automatically manages package managers without requiring global installations.

For architecture, conventions, data fetching, feature flags, design system usage, state management, and PR process, see [CONTRIBUTING.md](./CONTRIBUTING.md).

### Prerequisites

Make sure you have Node.js 16.10+ installed. Corepack is included with Node.js by default.

### ⚠️ Migrating from yarn

> This project previously used yarn 1; make sure to clean up the old files if you set it up with yarn before:
>
> ```bash
> rm -f yarn.lock && rm -rf node_modules
> ```
>
> Then follow the setup steps below.

## Setup

### 1. **Enable corepack** (run this once on your system):
@@ -96,184 +88,13 @@ Every time a new Front-end dependency is added by you or others, you will need t
|
||||
|
||||
This project uses [`next/font`](https://nextjs.org/docs/basic-features/font-optimization) to automatically optimize and load Inter, a custom Google Font.
|
||||
|
||||
## 🔄 Data Fetching Strategy
|
||||
## 🔄 Data Fetching
|
||||
|
||||
> [!NOTE]
|
||||
> You don't need to run the OpenAPI commands below to run the Front-end. You will only need to run them when adding or modifying endpoints on the Backend API and wanting to use those on the Frontend.
|
||||
|
||||
This project uses an auto-generated API client powered by [**Orval**](https://orval.dev/), which creates type-safe API clients from OpenAPI specifications.
|
||||
|
||||
### How It Works
|
||||
|
||||
1. **Backend Requirements**: Each API endpoint needs a summary and tag in the OpenAPI spec
|
||||
2. **Operation ID Generation**: FastAPI generates operation IDs using the pattern `{method}{tag}{summary}`
|
||||
3. **Spec Fetching**: The OpenAPI spec is fetched from `http://localhost:8006/openapi.json` and saved to the frontend
|
||||
4. **Spec Transformation**: The OpenAPI spec is cleaned up using a custom transformer (see `autogpt_platform/frontend/src/app/api/transformers`)
|
||||
5. **Client Generation**: Auto-generated client includes TypeScript types, API endpoints, and Zod schemas, organized by tags
|
||||
|
||||
### API Client Commands
|
||||
|
||||
```bash
|
||||
# Fetch OpenAPI spec from backend and generate client
|
||||
pnpm generate:api
|
||||
|
||||
# Only fetch the OpenAPI spec
|
||||
pnpm fetch:openapi
|
||||
|
||||
# Only generate the client (after spec is fetched)
|
||||
pnpm generate:api-client
|
||||
```
|
||||
|
||||
### Using the Generated Client
|
||||
|
||||
The generated client provides React Query hooks for both queries and mutations:
|
||||
|
||||
#### Queries (GET requests)
|
||||
|
||||
```typescript
|
||||
import { useGetV1GetNotificationPreferences } from "@/app/api/__generated__/endpoints/auth/auth";
|
||||
|
||||
const { data, isLoading, isError } = useGetV1GetNotificationPreferences({
|
||||
query: {
|
||||
select: (res) => res.data,
|
||||
// Other React Query options
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
#### Mutations (POST, PUT, DELETE requests)
|
||||
|
||||
```typescript
|
||||
import { useDeleteV2DeleteStoreSubmission } from "@/app/api/__generated__/endpoints/store/store";
|
||||
import { getGetV2ListMySubmissionsQueryKey } from "@/app/api/__generated__/endpoints/store/store";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
const { mutateAsync: deleteSubmission } = useDeleteV2DeleteStoreSubmission({
|
||||
mutation: {
|
||||
onSuccess: () => {
|
||||
// Invalidate related queries to refresh data
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListMySubmissionsQueryKey(),
|
||||
});
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Usage
|
||||
await deleteSubmission({
|
||||
submissionId: submission_id,
|
||||
});
|
||||
```
|
||||
|
||||
#### Server Actions
|
||||
|
||||
For server-side operations, you can also use the generated client functions directly:
|
||||
|
||||
```typescript
|
||||
import { postV1UpdateNotificationPreferences } from "@/app/api/__generated__/endpoints/auth/auth";
|
||||
|
||||
// In a server action
|
||||
const preferences = {
|
||||
email: "user@example.com",
|
||||
preferences: {
|
||||
AGENT_RUN: true,
|
||||
ZERO_BALANCE: false,
|
||||
// ... other preferences
|
||||
},
|
||||
daily_limit: 0,
|
||||
};
|
||||
|
||||
await postV1UpdateNotificationPreferences(preferences);
|
||||
```
|
||||
|
||||
#### Server-Side Prefetching
|
||||
|
||||
For server-side components, you can prefetch data on the server and hydrate it in the client cache. This allows immediate access to cached data when queries are called:
|
||||
|
||||
```typescript
|
||||
import { getQueryClient } from "@/lib/tanstack-query/getQueryClient";
|
||||
import {
|
||||
prefetchGetV2ListStoreAgentsQuery,
|
||||
prefetchGetV2ListStoreCreatorsQuery
|
||||
} from "@/app/api/__generated__/endpoints/store/store";
|
||||
import { HydrationBoundary, dehydrate } from "@tanstack/react-query";
|
||||
|
||||
// In your server component
|
||||
const queryClient = getQueryClient();
|
||||
|
||||
await Promise.all([
|
||||
prefetchGetV2ListStoreAgentsQuery(queryClient, {
|
||||
featured: true,
|
||||
}),
|
||||
prefetchGetV2ListStoreAgentsQuery(queryClient, {
|
||||
sorted_by: "runs",
|
||||
}),
|
||||
prefetchGetV2ListStoreCreatorsQuery(queryClient, {
|
||||
featured: true,
|
||||
sorted_by: "num_agents",
|
||||
}),
|
||||
]);
|
||||
|
||||
return (
|
||||
<HydrationBoundary state={dehydrate(queryClient)}>
|
||||
<MainMarkeplacePage />
|
||||
</HydrationBoundary>
|
||||
);
|
||||
```
|
||||
|
||||
This pattern improves performance by serving pre-fetched data from the server while maintaining the benefits of client-side React Query features.
|
||||
|
||||
### Configuration
|
||||
|
||||
The Orval configuration is located in `autogpt_platform/frontend/orval.config.ts`. It generates two separate clients:
|
||||
|
||||
1. **autogpt_api_client**: React Query hooks for client-side data fetching
|
||||
2. **autogpt_zod_schema**: Zod schemas for validation
|
||||
|
||||
For more details, see the [Orval documentation](https://orval.dev/) or check the configuration file.
|
||||
See [CONTRIBUTING.md](./CONTRIBUTING.md) for guidance on generated API hooks, SSR + hydration patterns, and usage examples. You generally do not need to run OpenAPI commands unless adding/modifying backend endpoints.
|
||||
|
||||
## 🚩 Feature Flags
|
||||
|
||||
This project uses [LaunchDarkly](https://launchdarkly.com/) for feature flags, allowing us to control feature rollouts and A/B testing.
|
||||
|
||||
### Using Feature Flags
|
||||
|
||||
#### Check if a feature is enabled
|
||||
|
||||
```typescript
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
|
||||
function MyComponent() {
|
||||
const isAgentActivityEnabled = useGetFlag(Flag.AGENT_ACTIVITY);
|
||||
|
||||
if (!isAgentActivityEnabled) {
|
||||
return null; // Hide feature
|
||||
}
|
||||
|
||||
return <div>Feature is enabled!</div>;
|
||||
}
|
||||
```
|
||||
|
||||
#### Protect entire components
|
||||
|
||||
```typescript
|
||||
import { withFeatureFlag } from "@/services/feature-flags/with-feature-flag";
|
||||
|
||||
const MyFeaturePage = withFeatureFlag(MyPageComponent, "my-feature-flag");
|
||||
```
|
||||
|
||||
### Testing with Feature Flags
|
||||
|
||||
For local development or running Playwright tests locally, use mocked feature flags by setting `NEXT_PUBLIC_PW_TEST=true` in your `.env` file. This bypasses LaunchDarkly and uses the mock values defined in the code.
|
||||
|
||||
### Adding New Flags
|
||||
|
||||
1. Add the flag to the `Flag` enum in `use-get-flag.ts`
|
||||
2. Add the flag type to `FlagValues` type
|
||||
3. Add mock value to `mockFlags` for testing
|
||||
4. Configure the flag in LaunchDarkly dashboard
|
||||
See [CONTRIBUTING.md](./CONTRIBUTING.md) for feature flag usage patterns, local development with mocks, and how to add new flags.
|
||||
|
||||
## 🚚 Deploy
|
||||
|
||||
@@ -333,7 +154,7 @@ By integrating Storybook into our development workflow, we can streamline UI dev
|
||||
- [**Tailwind CSS**](https://tailwindcss.com/) - Utility-first CSS framework
|
||||
- [**shadcn/ui**](https://ui.shadcn.com/) - Re-usable components built with Radix UI and Tailwind CSS
|
||||
- [**Radix UI**](https://www.radix-ui.com/) - Headless UI components for accessibility
|
||||
- [**Lucide React**](https://lucide.dev/guide/packages/lucide-react) - Beautiful & consistent icons
|
||||
- [**Phosphor Icons**](https://phosphoricons.com/) - Icon set used across the app
|
||||
- [**Framer Motion**](https://motion.dev/) - Animation library for React
|
||||
|
||||
### Development & Testing
|
||||
|
||||
@@ -2,18 +2,11 @@
|
||||
// The config you add here will be used whenever a users loads a page in their browser.
|
||||
// https://docs.sentry.io/platforms/javascript/guides/nextjs/
|
||||
|
||||
import {
|
||||
AppEnv,
|
||||
BehaveAs,
|
||||
getAppEnv,
|
||||
getBehaveAs,
|
||||
getEnvironmentStr,
|
||||
} from "@/lib/utils";
|
||||
import { environment } from "@/services/environment";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
|
||||
const isProdOrDev = [AppEnv.PROD, AppEnv.DEV].includes(getAppEnv());
|
||||
|
||||
const isCloud = getBehaveAs() === BehaveAs.CLOUD;
|
||||
const isProdOrDev = environment.isProd() || environment.isDev();
|
||||
const isCloud = environment.isCloud();
|
||||
const isDisabled = process.env.DISABLE_SENTRY === "true";
|
||||
|
||||
const shouldEnable = !isDisabled && isProdOrDev && isCloud;
|
||||
@@ -21,7 +14,7 @@ const shouldEnable = !isDisabled && isProdOrDev && isCloud;
|
||||
Sentry.init({
|
||||
dsn: "https://fe4e4aa4a283391808a5da396da20159@o4505260022104064.ingest.us.sentry.io/4507946746380288",
|
||||
|
||||
environment: getEnvironmentStr(),
|
||||
environment: environment.getEnvironmentStr(),
|
||||
|
||||
enabled: shouldEnable,
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@
|
||||
"@sentry/nextjs": "10.15.0",
|
||||
"@supabase/ssr": "0.6.1",
|
||||
"@supabase/supabase-js": "2.55.0",
|
||||
"@tanstack/react-query": "5.85.3",
|
||||
"@tanstack/react-query": "5.87.1",
|
||||
"@tanstack/react-table": "8.21.3",
|
||||
"@types/jaro-winkler": "0.2.4",
|
||||
"@vercel/analytics": "1.5.0",
|
||||
@@ -103,7 +103,7 @@
|
||||
"shepherd.js": "14.5.1",
|
||||
"sonner": "2.0.7",
|
||||
"tailwind-merge": "2.6.0",
|
||||
"tailwind-scrollbar": "4.0.2",
|
||||
"tailwind-scrollbar": "3.1.0",
|
||||
"tailwindcss-animate": "1.0.7",
|
||||
"uuid": "11.1.0",
|
||||
"vaul": "1.1.2",
|
||||
|
||||
144
autogpt_platform/frontend/pnpm-lock.yaml
generated
@@ -99,8 +99,8 @@ importers:
|
||||
specifier: 2.55.0
|
||||
version: 2.55.0
|
||||
'@tanstack/react-query':
|
||||
specifier: 5.85.3
|
||||
version: 5.85.3(react@18.3.1)
|
||||
specifier: 5.87.1
|
||||
version: 5.87.1(react@18.3.1)
|
||||
'@tanstack/react-table':
|
||||
specifier: 8.21.3
|
||||
version: 8.21.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
@@ -243,8 +243,8 @@ importers:
|
||||
specifier: 2.6.0
|
||||
version: 2.6.0
|
||||
tailwind-scrollbar:
|
||||
specifier: 4.0.2
|
||||
version: 4.0.2(react@18.3.1)(tailwindcss@3.4.17)
|
||||
specifier: 3.1.0
|
||||
version: 3.1.0(tailwindcss@3.4.17)
|
||||
tailwindcss-animate:
|
||||
specifier: 1.0.7
|
||||
version: 1.0.7(tailwindcss@3.4.17)
|
||||
@@ -287,7 +287,7 @@ importers:
|
||||
version: 5.86.0(eslint@8.57.1)(typescript@5.9.2)
|
||||
'@tanstack/react-query-devtools':
|
||||
specifier: 5.87.3
|
||||
version: 5.87.3(@tanstack/react-query@5.85.3(react@18.3.1))(react@18.3.1)
|
||||
version: 5.87.3(@tanstack/react-query@5.87.1(react@18.3.1))(react@18.3.1)
|
||||
'@types/canvas-confetti':
|
||||
specifier: 1.9.0
|
||||
version: 1.9.0
|
||||
@@ -947,10 +947,6 @@ packages:
|
||||
peerDependencies:
|
||||
'@babel/core': ^7.0.0-0
|
||||
|
||||
'@babel/runtime@7.28.3':
|
||||
resolution: {integrity: sha512-9uIQ10o0WGdpP6GDhXcdOJPJuDgFtIDtN/9+ArJQ2NAfAmiuhTQdzkaTGR33v43GYS2UrSA0eX2pPPHoFVvpxA==}
|
||||
engines: {node: '>=6.9.0'}
|
||||
|
||||
'@babel/runtime@7.28.4':
|
||||
resolution: {integrity: sha512-Q/N6JNWvIvPnLDvjlE1OUBLPQHH6l3CltCEsHIujp45zQUSSh8K+gHnaEX45yAT1nyngnINhvWtzN+Nb9D8RAQ==}
|
||||
engines: {node: '>=6.9.0'}
|
||||
@@ -985,9 +981,6 @@ packages:
|
||||
'@emnapi/core@1.5.0':
|
||||
resolution: {integrity: sha512-sbP8GzB1WDzacS8fgNPpHlp6C9VZe+SJP3F90W9rLemaQj2PzIuTEl1qDOYQf58YIpyjViI24y9aPWCjEzY2cg==}
|
||||
|
||||
'@emnapi/runtime@1.4.5':
|
||||
resolution: {integrity: sha512-++LApOtY0pEEz1zrd9vy1/zXVaVJJ/EbAF3u0fXIzPJEDtnITsBGbbK0EkM72amhl/R5b+5xx0Y/QhcVOpuulg==}
|
||||
|
||||
'@emnapi/runtime@1.5.0':
|
||||
resolution: {integrity: sha512-97/BJ3iXHww3djw6hYIfErCZFee7qCtrneuLa20UXFCOTCfBM2cvQHjWJ2EG0s0MtdNwInarqCTz35i4wWXHsQ==}
|
||||
|
||||
@@ -1159,12 +1152,6 @@ packages:
|
||||
cpu: [x64]
|
||||
os: [win32]
|
||||
|
||||
'@eslint-community/eslint-utils@4.7.0':
|
||||
resolution: {integrity: sha512-dyybb3AcajC7uha6CvhdVRJqaKyn7w2YKqKyAN37NKYgZT36w+iRb0Dymmc5qEJ549c/S31cMMSFd75bteCpCw==}
|
||||
engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0}
|
||||
peerDependencies:
|
||||
eslint: ^6.0.0 || ^7.0.0 || >=8.0.0
|
||||
|
||||
'@eslint-community/eslint-utils@4.9.0':
|
||||
resolution: {integrity: sha512-ayVFHdtZ+hsq1t2Dy24wCmGXGe4q9Gu3smhLYALJrr473ZH27MsnSL+LKUlimp4BWJqMDMLmPpx/Q9R3OAlL4g==}
|
||||
engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0}
|
||||
@@ -2856,8 +2843,8 @@ packages:
|
||||
peerDependencies:
|
||||
eslint: ^8.57.0 || ^9.0.0
|
||||
|
||||
'@tanstack/query-core@5.85.3':
|
||||
resolution: {integrity: sha512-9Ne4USX83nHmRuEYs78LW+3lFEEO2hBDHu7mrdIgAFx5Zcrs7ker3n/i8p4kf6OgKExmaDN5oR0efRD7i2J0DQ==}
|
||||
'@tanstack/query-core@5.87.1':
|
||||
resolution: {integrity: sha512-HOFHVvhOCprrWvtccSzc7+RNqpnLlZ5R6lTmngb8aq7b4rc2/jDT0w+vLdQ4lD9bNtQ+/A4GsFXy030Gk4ollA==}
|
||||
|
||||
'@tanstack/query-devtools@5.87.3':
|
||||
resolution: {integrity: sha512-LkzxzSr2HS1ALHTgDmJH5eGAVsSQiuwz//VhFW5OqNk0OQ+Fsqba0Tsf+NzWRtXYvpgUqwQr4b2zdFZwxHcGvg==}
|
||||
@@ -2868,8 +2855,8 @@ packages:
|
||||
'@tanstack/react-query': ^5.87.1
|
||||
react: ^18 || ^19
|
||||
|
||||
'@tanstack/react-query@5.85.3':
|
||||
resolution: {integrity: sha512-AqU8TvNh5GVIE8I+TUU0noryBRy7gOY0XhSayVXmOPll4UkZeLWKDwi0rtWOZbwLRCbyxorfJ5DIjDqE7GXpcQ==}
|
||||
'@tanstack/react-query@5.87.1':
|
||||
resolution: {integrity: sha512-YKauf8jfMowgAqcxj96AHs+Ux3m3bWT1oSVKamaRPXSnW2HqSznnTCEkAVqctF1e/W9R/mPcyzzINIgpOH94qg==}
|
||||
peerDependencies:
|
||||
react: ^18 || ^19
|
||||
|
||||
@@ -3045,9 +3032,6 @@ packages:
|
||||
'@types/phoenix@1.6.6':
|
||||
resolution: {integrity: sha512-PIzZZlEppgrpoT2QgbnDU+MMzuR6BbCjllj0bM70lWoejMeNJAxCchxnv7J3XFkI8MpygtRpzXrIlmWUBclP5A==}
|
||||
|
||||
'@types/prismjs@1.26.5':
|
||||
resolution: {integrity: sha512-AUZTa7hQ2KY5L7AmtSiqxlhWxb4ina0yd8hNbl4TWuqnv/pFP0nDMb3YrfSBf4hJVGLh2YEIBfKaBW/9UEl6IQ==}
|
||||
|
||||
'@types/prop-types@15.7.15':
|
||||
resolution: {integrity: sha512-F6bEyamV9jKGAFBEmlQnesRPGOQqS2+Uwi0Em15xenOxHaf2hv6L8YCVn3rPdPJOiJfPiCnLIRyvwVaqMY3MIw==}
|
||||
|
||||
@@ -3740,9 +3724,6 @@ packages:
|
||||
camelize@1.0.1:
|
||||
resolution: {integrity: sha512-dU+Tx2fsypxTgtLoE36npi3UqcjSSMNYfkqgmoEhtZrraP5VWq0K7FkWVTYa8eMPtnU/G2txVsfdCJTn9uzpuQ==}
|
||||
|
||||
caniuse-lite@1.0.30001735:
|
||||
resolution: {integrity: sha512-EV/laoX7Wq2J9TQlyIXRxTJqIw4sxfXS4OYgudGxBYRuTv0q7AM6yMEpU/Vo1I94thg9U6EZ2NfZx9GJq83u7w==}
|
||||
|
||||
caniuse-lite@1.0.30001741:
|
||||
resolution: {integrity: sha512-QGUGitqsc8ARjLdgAfxETDhRbJ0REsP6O3I96TAth/mVjh2cYzN2u+3AzPP3aVSm2FehEItaJw1xd+IGBXWeSw==}
|
||||
|
||||
@@ -4108,15 +4089,6 @@ packages:
|
||||
supports-color:
|
||||
optional: true
|
||||
|
||||
debug@4.4.1:
|
||||
resolution: {integrity: sha512-KcKCqiftBJcZr++7ykoDIEwSa3XWowTfNPo92BYxjXiyYEVrUQh2aLyhxBCwww+heortUFxEJYcRzosstTEBYQ==}
|
||||
engines: {node: '>=6.0'}
|
||||
peerDependencies:
|
||||
supports-color: '*'
|
||||
peerDependenciesMeta:
|
||||
supports-color:
|
||||
optional: true
|
||||
|
||||
debug@4.4.3:
|
||||
resolution: {integrity: sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA==}
|
||||
engines: {node: '>=6.0'}
|
||||
@@ -6220,11 +6192,6 @@ packages:
|
||||
resolution: {integrity: sha512-Qb1gy5OrP5+zDf2Bvnzdl3jsTf1qXVMazbvCoKhtKqVs4/YK4ozX4gKQJJVyNe+cajNPn0KoC0MC3FUmaHWEmQ==}
|
||||
engines: {node: ^10.13.0 || ^12.13.0 || ^14.15.0 || >=15.0.0}
|
||||
|
||||
prism-react-renderer@2.4.1:
|
||||
resolution: {integrity: sha512-ey8Ls/+Di31eqzUxC46h8MksNuGx/n0AAC8uKpwFau4RPDYLuE3EXTp8N8G2vX2N7UC/+IXeNUnlWBGGcAG+Ig==}
|
||||
peerDependencies:
|
||||
react: '>=16.0.0'
|
||||
|
||||
process-nextick-args@2.0.1:
|
||||
resolution: {integrity: sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==}
|
||||
|
||||
@@ -6949,11 +6916,11 @@ packages:
|
||||
tailwind-merge@2.6.0:
|
||||
resolution: {integrity: sha512-P+Vu1qXfzediirmHOC3xKGAYeZtPcV9g76X+xg2FD4tYgR71ewMA35Y3sCz3zhiN/dwefRpJX0yBcgwi1fXNQA==}
|
||||
|
||||
tailwind-scrollbar@4.0.2:
|
||||
resolution: {integrity: sha512-wAQiIxAPqk0MNTPptVe/xoyWi27y+NRGnTwvn4PQnbvB9kp8QUBiGl/wsfoVBHnQxTmhXJSNt9NHTmcz9EivFA==}
|
||||
tailwind-scrollbar@3.1.0:
|
||||
resolution: {integrity: sha512-pmrtDIZeHyu2idTejfV59SbaJyvp1VRjYxAjZBH0jnyrPRo6HL1kD5Glz8VPagasqr6oAx6M05+Tuw429Z8jxg==}
|
||||
engines: {node: '>=12.13.0'}
|
||||
peerDependencies:
|
||||
tailwindcss: 4.x
|
||||
tailwindcss: 3.x
|
||||
|
||||
tailwindcss-animate@1.0.7:
|
||||
resolution: {integrity: sha512-bl6mpH3T7I3UFxuvDEXLxy/VuFxBk5bbzplh7tXI68mwMokNYd1t9qPBHlnyTwfa4JGC4zP516I1hYYtQ/vspA==}
|
||||
@@ -7567,7 +7534,7 @@ snapshots:
|
||||
'@babel/types': 7.28.4
|
||||
'@jridgewell/remapping': 2.3.5
|
||||
convert-source-map: 2.0.0
|
||||
debug: 4.4.1
|
||||
debug: 4.4.3
|
||||
gensync: 1.0.0-beta.2
|
||||
json5: 2.2.3
|
||||
semver: 6.3.1
|
||||
@@ -7619,7 +7586,7 @@ snapshots:
|
||||
'@babel/core': 7.28.4
|
||||
'@babel/helper-compilation-targets': 7.27.2
|
||||
'@babel/helper-plugin-utils': 7.27.1
|
||||
debug: 4.4.1
|
||||
debug: 4.4.3
|
||||
lodash.debounce: 4.0.8
|
||||
resolve: 1.22.10
|
||||
transitivePeerDependencies:
|
||||
@@ -8270,8 +8237,6 @@ snapshots:
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
'@babel/runtime@7.28.3': {}
|
||||
|
||||
'@babel/runtime@7.28.4': {}
|
||||
|
||||
'@babel/template@7.27.2':
|
||||
@@ -8288,7 +8253,7 @@ snapshots:
|
||||
'@babel/parser': 7.28.4
|
||||
'@babel/template': 7.27.2
|
||||
'@babel/types': 7.28.4
|
||||
debug: 4.4.1
|
||||
debug: 4.4.3
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
@@ -8325,11 +8290,6 @@ snapshots:
|
||||
tslib: 2.8.1
|
||||
optional: true
|
||||
|
||||
'@emnapi/runtime@1.4.5':
|
||||
dependencies:
|
||||
tslib: 2.8.1
|
||||
optional: true
|
||||
|
||||
'@emnapi/runtime@1.5.0':
|
||||
dependencies:
|
||||
tslib: 2.8.1
|
||||
@@ -8426,11 +8386,6 @@ snapshots:
|
||||
'@esbuild/win32-x64@0.25.9':
|
||||
optional: true
|
||||
|
||||
'@eslint-community/eslint-utils@4.7.0(eslint@8.57.1)':
|
||||
dependencies:
|
||||
eslint: 8.57.1
|
||||
eslint-visitor-keys: 3.4.3
|
||||
|
||||
'@eslint-community/eslint-utils@4.9.0(eslint@8.57.1)':
|
||||
dependencies:
|
||||
eslint: 8.57.1
|
||||
@@ -8441,7 +8396,7 @@ snapshots:
|
||||
'@eslint/eslintrc@2.1.4':
|
||||
dependencies:
|
||||
ajv: 6.12.6
|
||||
debug: 4.4.1
|
||||
debug: 4.4.3
|
||||
espree: 9.6.1
|
||||
globals: 13.24.0
|
||||
ignore: 5.3.2
|
||||
@@ -8491,7 +8446,7 @@ snapshots:
|
||||
'@humanwhocodes/config-array@0.13.0':
|
||||
dependencies:
|
||||
'@humanwhocodes/object-schema': 2.0.3
|
||||
debug: 4.4.1
|
||||
debug: 4.4.3
|
||||
minimatch: 3.1.2
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -8592,7 +8547,7 @@ snapshots:
|
||||
|
||||
'@img/sharp-wasm32@0.34.3':
|
||||
dependencies:
|
||||
'@emnapi/runtime': 1.4.5
|
||||
'@emnapi/runtime': 1.5.0
|
||||
optional: true
|
||||
|
||||
'@img/sharp-win32-arm64@0.34.3':
|
||||
@@ -9041,7 +8996,7 @@ snapshots:
|
||||
ajv: 8.17.1
|
||||
chalk: 4.1.2
|
||||
compare-versions: 6.1.1
|
||||
debug: 4.4.1
|
||||
debug: 4.4.3
|
||||
esbuild: 0.25.9
|
||||
esutils: 2.0.3
|
||||
fs-extra: 11.3.1
|
||||
@@ -10373,7 +10328,7 @@ snapshots:
|
||||
|
||||
'@storybook/react-docgen-typescript-plugin@1.0.6--canary.9.0c3f3b7.0(typescript@5.9.2)(webpack@5.101.3(esbuild@0.25.9))':
|
||||
dependencies:
|
||||
debug: 4.4.1
|
||||
debug: 4.4.3
|
||||
endent: 2.1.0
|
||||
find-cache-dir: 3.3.2
|
||||
flat-cache: 3.2.0
|
||||
@@ -10460,19 +10415,19 @@ snapshots:
|
||||
- supports-color
|
||||
- typescript
|
||||
|
||||
'@tanstack/query-core@5.85.3': {}
|
||||
'@tanstack/query-core@5.87.1': {}
|
||||
|
||||
'@tanstack/query-devtools@5.87.3': {}
|
||||
|
||||
'@tanstack/react-query-devtools@5.87.3(@tanstack/react-query@5.85.3(react@18.3.1))(react@18.3.1)':
|
||||
'@tanstack/react-query-devtools@5.87.3(@tanstack/react-query@5.87.1(react@18.3.1))(react@18.3.1)':
|
||||
dependencies:
|
||||
'@tanstack/query-devtools': 5.87.3
|
||||
'@tanstack/react-query': 5.85.3(react@18.3.1)
|
||||
'@tanstack/react-query': 5.87.1(react@18.3.1)
|
||||
react: 18.3.1
|
||||
|
||||
'@tanstack/react-query@5.85.3(react@18.3.1)':
|
||||
'@tanstack/react-query@5.87.1(react@18.3.1)':
|
||||
dependencies:
|
||||
'@tanstack/query-core': 5.85.3
|
||||
'@tanstack/query-core': 5.87.1
|
||||
react: 18.3.1
|
||||
|
||||
'@tanstack/react-table@8.21.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1)':
|
||||
@@ -10664,8 +10619,6 @@ snapshots:
|
||||
|
||||
'@types/phoenix@1.6.6': {}
|
||||
|
||||
'@types/prismjs@1.26.5': {}
|
||||
|
||||
'@types/prop-types@15.7.15': {}
|
||||
|
||||
'@types/react-dom@18.3.5(@types/react@18.3.17)':
|
||||
@@ -10734,7 +10687,7 @@ snapshots:
|
||||
'@typescript-eslint/types': 8.43.0
|
||||
'@typescript-eslint/typescript-estree': 8.43.0(typescript@5.9.2)
|
||||
'@typescript-eslint/visitor-keys': 8.43.0
|
||||
debug: 4.4.1
|
||||
debug: 4.4.3
|
||||
eslint: 8.57.1
|
||||
typescript: 5.9.2
|
||||
transitivePeerDependencies:
|
||||
@@ -10744,7 +10697,7 @@ snapshots:
|
||||
dependencies:
|
||||
'@typescript-eslint/tsconfig-utils': 8.43.0(typescript@5.9.2)
|
||||
'@typescript-eslint/types': 8.43.0
|
||||
debug: 4.4.1
|
||||
debug: 4.4.3
|
||||
typescript: 5.9.2
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -10763,7 +10716,7 @@ snapshots:
|
||||
'@typescript-eslint/types': 8.43.0
|
||||
'@typescript-eslint/typescript-estree': 8.43.0(typescript@5.9.2)
|
||||
'@typescript-eslint/utils': 8.43.0(eslint@8.57.1)(typescript@5.9.2)
|
||||
debug: 4.4.1
|
||||
debug: 4.4.3
|
||||
eslint: 8.57.1
|
||||
ts-api-utils: 2.1.0(typescript@5.9.2)
|
||||
typescript: 5.9.2
|
||||
@@ -10778,7 +10731,7 @@ snapshots:
|
||||
'@typescript-eslint/tsconfig-utils': 8.43.0(typescript@5.9.2)
|
||||
'@typescript-eslint/types': 8.43.0
|
||||
'@typescript-eslint/visitor-keys': 8.43.0
|
||||
debug: 4.4.1
|
||||
debug: 4.4.3
|
||||
fast-glob: 3.3.3
|
||||
is-glob: 4.0.3
|
||||
minimatch: 9.0.5
|
||||
@@ -11395,8 +11348,6 @@ snapshots:
|
||||
|
||||
camelize@1.0.1: {}
|
||||
|
||||
caniuse-lite@1.0.30001735: {}
|
||||
|
||||
caniuse-lite@1.0.30001741: {}
|
||||
|
||||
case-sensitive-paths-webpack-plugin@2.4.0: {}
|
||||
@@ -11598,7 +11549,7 @@ snapshots:
|
||||
dependencies:
|
||||
cipher-base: 1.0.6
|
||||
inherits: 2.0.4
|
||||
ripemd160: 2.0.1
|
||||
ripemd160: 2.0.2
|
||||
sha.js: 2.4.12
|
||||
|
||||
create-hash@1.2.0:
|
||||
@@ -11612,9 +11563,9 @@ snapshots:
|
||||
create-hmac@1.1.7:
|
||||
dependencies:
|
||||
cipher-base: 1.0.6
|
||||
create-hash: 1.1.3
|
||||
create-hash: 1.2.0
|
||||
inherits: 2.0.4
|
||||
ripemd160: 2.0.1
|
||||
ripemd160: 2.0.2
|
||||
safe-buffer: 5.2.1
|
||||
sha.js: 2.4.12
|
||||
|
||||
@@ -11772,10 +11723,6 @@ snapshots:
|
||||
dependencies:
|
||||
ms: 2.1.3
|
||||
|
||||
debug@4.4.1:
|
||||
dependencies:
|
||||
ms: 2.1.3
|
||||
|
||||
debug@4.4.3:
|
||||
dependencies:
|
||||
ms: 2.1.3
|
||||
@@ -12077,7 +12024,7 @@ snapshots:
|
||||
|
||||
esbuild-register@3.6.0(esbuild@0.25.9):
|
||||
dependencies:
|
||||
debug: 4.4.1
|
||||
debug: 4.4.3
|
||||
esbuild: 0.25.9
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -12148,7 +12095,7 @@ snapshots:
|
||||
eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1):
|
||||
dependencies:
|
||||
'@nolyfill/is-core-module': 1.0.39
|
||||
debug: 4.4.1
|
||||
debug: 4.4.3
|
||||
eslint: 8.57.1
|
||||
get-tsconfig: 4.10.1
|
||||
is-bun-module: 2.0.0
|
||||
@@ -12270,7 +12217,7 @@ snapshots:
|
||||
|
||||
eslint@8.57.1:
|
||||
dependencies:
|
||||
'@eslint-community/eslint-utils': 4.7.0(eslint@8.57.1)
|
||||
'@eslint-community/eslint-utils': 4.9.0(eslint@8.57.1)
|
||||
'@eslint-community/regexpp': 4.12.1
|
||||
'@eslint/eslintrc': 2.1.4
|
||||
'@eslint/js': 8.57.1
|
||||
@@ -12281,7 +12228,7 @@ snapshots:
|
||||
ajv: 6.12.6
|
||||
chalk: 4.1.2
|
||||
cross-spawn: 7.0.6
|
||||
debug: 4.4.1
|
||||
debug: 4.4.3
|
||||
doctrine: 3.0.0
|
||||
escape-string-regexp: 4.0.0
|
||||
eslint-scope: 7.2.2
|
||||
@@ -13654,7 +13601,7 @@ snapshots:
|
||||
micromark@4.0.2:
|
||||
dependencies:
|
||||
'@types/debug': 4.1.12
|
||||
debug: 4.4.1
|
||||
debug: 4.4.3
|
||||
decode-named-character-reference: 1.2.0
|
||||
devlop: 1.1.0
|
||||
micromark-core-commonmark: 2.0.3
|
||||
@@ -13790,7 +13737,7 @@ snapshots:
|
||||
dependencies:
|
||||
'@next/env': 15.4.7
|
||||
'@swc/helpers': 0.5.15
|
||||
caniuse-lite: 1.0.30001735
|
||||
caniuse-lite: 1.0.30001741
|
||||
postcss: 8.4.31
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
@@ -14311,12 +14258,6 @@ snapshots:
|
||||
ansi-styles: 5.2.0
|
||||
react-is: 17.0.2
|
||||
|
||||
prism-react-renderer@2.4.1(react@18.3.1):
|
||||
dependencies:
|
||||
'@types/prismjs': 1.26.5
|
||||
clsx: 2.1.1
|
||||
react: 18.3.1
|
||||
|
||||
process-nextick-args@2.0.1: {}
|
||||
|
||||
process@0.11.10: {}
|
||||
@@ -14495,7 +14436,7 @@ snapshots:
|
||||
|
||||
react-window@1.8.11(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
|
||||
dependencies:
|
||||
'@babel/runtime': 7.28.3
|
||||
'@babel/runtime': 7.28.4
|
||||
memoize-one: 5.2.1
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
@@ -14716,7 +14657,7 @@ snapshots:
|
||||
|
||||
require-in-the-middle@7.5.2:
|
||||
dependencies:
|
||||
debug: 4.4.1
|
||||
debug: 4.4.3
|
||||
module-details-from-path: 1.0.4
|
||||
resolve: 1.22.10
|
||||
transitivePeerDependencies:
|
||||
@@ -15259,12 +15200,9 @@ snapshots:
|
||||
|
||||
tailwind-merge@2.6.0: {}
|
||||
|
||||
tailwind-scrollbar@4.0.2(react@18.3.1)(tailwindcss@3.4.17):
|
||||
tailwind-scrollbar@3.1.0(tailwindcss@3.4.17):
|
||||
dependencies:
|
||||
prism-react-renderer: 2.4.1(react@18.3.1)
|
||||
tailwindcss: 3.4.17
|
||||
transitivePeerDependencies:
|
||||
- react
|
||||
|
||||
tailwindcss-animate@1.0.7(tailwindcss@3.4.17):
|
||||
dependencies:
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
#!/usr/bin/env node
|
||||
|
||||
import { getAgptServerBaseUrl } from "@/lib/env-config";
|
||||
import { execSync } from "child_process";
|
||||
import * as path from "path";
|
||||
import * as fs from "fs";
|
||||
import * as os from "os";
|
||||
import { environment } from "@/services/environment";
|
||||
|
||||
function fetchOpenApiSpec(): void {
|
||||
const args = process.argv.slice(2);
|
||||
const forceFlag = args.includes("--force");
|
||||
|
||||
const baseUrl = getAgptServerBaseUrl();
|
||||
const baseUrl = environment.getAGPTServerBaseUrl();
|
||||
const openApiUrl = `${baseUrl}/openapi.json`;
|
||||
const outputPath = path.join(
|
||||
__dirname,
|
||||
|
||||
@@ -3,18 +3,11 @@
|
||||
// Note that this config is unrelated to the Vercel Edge Runtime and is also required when running locally.
|
||||
// https://docs.sentry.io/platforms/javascript/guides/nextjs/
|
||||
|
||||
import { environment } from "@/services/environment";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import {
|
||||
AppEnv,
|
||||
BehaveAs,
|
||||
getAppEnv,
|
||||
getBehaveAs,
|
||||
getEnvironmentStr,
|
||||
} from "./src/lib/utils";
|
||||
|
||||
const isProdOrDev = [AppEnv.PROD, AppEnv.DEV].includes(getAppEnv());
|
||||
|
||||
const isCloud = getBehaveAs() === BehaveAs.CLOUD;
|
||||
const isProdOrDev = environment.isProd() || environment.isDev();
|
||||
const isCloud = environment.isCloud();
|
||||
const isDisabled = process.env.DISABLE_SENTRY === "true";
|
||||
|
||||
const shouldEnable = !isDisabled && isProdOrDev && isCloud;
|
||||
@@ -22,7 +15,7 @@ const shouldEnable = !isDisabled && isProdOrDev && isCloud;
|
||||
Sentry.init({
|
||||
dsn: "https://fe4e4aa4a283391808a5da396da20159@o4505260022104064.ingest.us.sentry.io/4507946746380288",
|
||||
|
||||
environment: getEnvironmentStr(),
|
||||
environment: environment.getEnvironmentStr(),
|
||||
|
||||
enabled: shouldEnable,
|
||||
|
||||
@@ -40,7 +33,7 @@ Sentry.init({
|
||||
|
||||
enableLogs: true,
|
||||
integrations: [
|
||||
Sentry.captureConsoleIntegration(),
|
||||
Sentry.captureConsoleIntegration({ levels: ["fatal", "error", "warn"] }),
|
||||
Sentry.extraErrorDataIntegration(),
|
||||
],
|
||||
});
|
||||
|
||||
@@ -2,19 +2,12 @@
|
||||
// The config you add here will be used whenever the server handles a request.
|
||||
// https://docs.sentry.io/platforms/javascript/guides/nextjs/
|
||||
|
||||
import {
|
||||
AppEnv,
|
||||
BehaveAs,
|
||||
getAppEnv,
|
||||
getBehaveAs,
|
||||
getEnvironmentStr,
|
||||
} from "@/lib/utils";
|
||||
import { environment } from "@/services/environment";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
// import { NodeProfilingIntegration } from "@sentry/profiling-node";
|
||||
|
||||
const isProdOrDev = [AppEnv.PROD, AppEnv.DEV].includes(getAppEnv());
|
||||
|
||||
const isCloud = getBehaveAs() === BehaveAs.CLOUD;
|
||||
const isProdOrDev = environment.isProd() || environment.isDev();
|
||||
const isCloud = environment.isCloud();
|
||||
const isDisabled = process.env.DISABLE_SENTRY === "true";
|
||||
|
||||
const shouldEnable = !isDisabled && isProdOrDev && isCloud;
|
||||
@@ -22,7 +15,7 @@ const shouldEnable = !isDisabled && isProdOrDev && isCloud;
|
||||
Sentry.init({
|
||||
dsn: "https://fe4e4aa4a283391808a5da396da20159@o4505260022104064.ingest.us.sentry.io/4507946746380288",
|
||||
|
||||
environment: getEnvironmentStr(),
|
||||
environment: environment.getEnvironmentStr(),
|
||||
|
||||
enabled: shouldEnable,
|
||||
|
||||
|
||||
@@ -10,9 +10,9 @@ import OnboardingAgentCard from "../components/OnboardingAgentCard";
|
||||
import { useEffect, useState } from "react";
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
import { StoreAgentDetails } from "@/lib/autogpt-server-api";
|
||||
import { finishOnboarding } from "../6-congrats/actions";
|
||||
import { isEmptyOrWhitespace } from "@/lib/utils";
|
||||
import { useOnboarding } from "../../../../providers/onboarding/onboarding-provider";
|
||||
import { finishOnboarding } from "../6-congrats/actions";
|
||||
|
||||
export default function Page() {
|
||||
const { state, updateState } = useOnboarding(4, "INTEGRATIONS");
|
||||
@@ -24,6 +24,7 @@ export default function Page() {
|
||||
if (agents.length < 2) {
|
||||
finishOnboarding();
|
||||
}
|
||||
|
||||
setAgents(agents);
|
||||
});
|
||||
}, [api, setAgents]);
|
||||
|
||||
@@ -0,0 +1,62 @@
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/AgentRunsView/components/CredentialsInputs/CredentialsInputs";
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
import { useState } from "react";
import { getSchemaDefaultCredentials } from "../../helpers";
import { areAllCredentialsSet, getCredentialFields } from "./helpers";

type Credential = CredentialsMetaInput | undefined;
type Credentials = Record<string, Credential>;

type Props = {
  agent: GraphMeta | null;
  siblingInputs?: Record<string, any>;
  onCredentialsChange: (
    credentials: Record<string, CredentialsMetaInput>,
  ) => void;
  onValidationChange: (isValid: boolean) => void;
  onLoadingChange: (isLoading: boolean) => void;
};

export function AgentOnboardingCredentials(props: Props) {
  const [inputCredentials, setInputCredentials] = useState<Credentials>({});

  const fields = getCredentialFields(props.agent);
  const required = Object.keys(fields || {}).length > 0;

  if (!required) return null;

  function handleSelectCredentials(key: string, value: Credential) {
    const updated = { ...inputCredentials, [key]: value };
    setInputCredentials(updated);

    const sanitized: Record<string, CredentialsMetaInput> = {};
    for (const [k, v] of Object.entries(updated)) {
      if (v) sanitized[k] = v;
    }

    props.onCredentialsChange(sanitized);

    const isValid = !required || areAllCredentialsSet(fields, updated);
    props.onValidationChange(isValid);
  }

  return (
    <>
      {Object.entries(fields).map(([key, inputSubSchema]) => (
        <div key={key} className="mt-4">
          <CredentialsInput
            schema={inputSubSchema}
            selectedCredentials={
              inputCredentials[key] ??
              getSchemaDefaultCredentials(inputSubSchema)
            }
            onSelectCredentials={(value) => handleSelectCredentials(key, value)}
            siblingInputs={props.siblingInputs}
            onLoaded={(loaded) => props.onLoadingChange(!loaded)}
          />
        </div>
      ))}
    </>
  );
}
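For orientation, here is a minimal sketch of how a parent could wire up the component above. The parent name, state variables, and relative import path are hypothetical; in this diff the real consumer is the onboarding run page further down, which routes these callbacks through useOnboardingRunStep.

```tsx
import { useState } from "react";
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
// Hypothetical relative path; adjust to wherever the component lives.
import { AgentOnboardingCredentials } from "./AgentOnboardingCredentials";

// Hypothetical parent: `agent` would come from the onboarding flow.
export function CredentialsSection({ agent }: { agent: GraphMeta | null }) {
  const [credentials, setCredentials] = useState<
    Record<string, CredentialsMetaInput>
  >({});
  const [valid, setValid] = useState(false);
  const [loading, setLoading] = useState(true);

  return (
    <AgentOnboardingCredentials
      agent={agent}
      onCredentialsChange={setCredentials} // sanitized map of chosen credentials
      onValidationChange={setValid} // true once every required field is set
      onLoadingChange={setLoading} // true while credential inputs are still loading
    />
  );
}
```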
@@ -0,0 +1,32 @@
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
import { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api/types";

export function getCredentialFields(
  agent: GraphMeta | null,
): AgentCredentialsFields {
  if (!agent) return {};

  const hasNoInputs =
    !agent.credentials_input_schema ||
    typeof agent.credentials_input_schema !== "object" ||
    !("properties" in agent.credentials_input_schema) ||
    !agent.credentials_input_schema.properties;

  if (hasNoInputs) return {};

  return agent.credentials_input_schema.properties as AgentCredentialsFields;
}

export type AgentCredentialsFields = Record<
  string,
  BlockIOCredentialsSubSchema
>;

export function areAllCredentialsSet(
  fields: AgentCredentialsFields,
  inputs: Record<string, CredentialsMetaInput | undefined>,
) {
  const required = Object.keys(fields || {});
  return required.every((k) => Boolean(inputs[k]));
}
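A small illustration of how these two helpers are meant to combine; the agent and credential values are declared as hypothetical placeholders rather than real AutoGPT data:

```ts
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
import { areAllCredentialsSet, getCredentialFields } from "./helpers";

declare const agent: GraphMeta | null; // hypothetical: loaded elsewhere
declare const githubCred: CredentialsMetaInput; // hypothetical user selections
declare const openaiCred: CredentialsMetaInput;

// Empty object when the agent has no credentials_input_schema properties.
const fields = getCredentialFields(agent);

// False while any key in `fields` still has no credential selected...
areAllCredentialsSet(fields, { github: githubCred, openai: undefined });

// ...true once every required key maps to a chosen credential.
areAllCredentialsSet(fields, { github: githubCred, openai: openaiCred });
```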
@@ -0,0 +1,45 @@
import { cn } from "@/lib/utils";
import { OnboardingText } from "../../components/OnboardingText";

type RunAgentHintProps = {
  handleNewRun: () => void;
};

export function RunAgentHint(props: RunAgentHintProps) {
  return (
    <div className="ml-[104px] w-[481px] pl-5">
      <div className="flex flex-col">
        <OnboardingText variant="header">Run your first agent</OnboardingText>
        <span className="mt-9 text-base font-normal leading-normal text-zinc-600">
          A 'run' is when your agent starts working on a task
        </span>
        <span className="mt-4 text-base font-normal leading-normal text-zinc-600">
          Click on <b>New Run</b> below to try it out
        </span>

        <div
          onClick={props.handleNewRun}
          className={cn(
            "mt-16 flex h-[68px] w-[330px] items-center justify-center rounded-xl border-2 border-violet-700 bg-neutral-50",
            "cursor-pointer transition-all duration-200 ease-in-out hover:bg-violet-50",
          )}
        >
          <svg
            width="38"
            height="38"
            viewBox="0 0 32 32"
            xmlns="http://www.w3.org/2000/svg"
          >
            <g stroke="#6d28d9" strokeWidth="1.2" strokeLinecap="round">
              <line x1="16" y1="8" x2="16" y2="24" />
              <line x1="8" y1="16" x2="24" y2="16" />
            </g>
          </svg>
          <span className="ml-3 font-sans text-[19px] font-medium leading-normal text-violet-700">
            New run
          </span>
        </div>
      </div>
    </div>
  );
}
@@ -0,0 +1,52 @@
import { StoreAgentDetails } from "@/app/api/__generated__/models/storeAgentDetails";
import StarRating from "../../components/StarRating";
import SmartImage from "@/components/__legacy__/SmartImage";

type Props = {
  storeAgent: StoreAgentDetails | null;
};

export function SelectedAgentCard(props: Props) {
  return (
    <div className="fixed left-1/4 top-1/2 w-[481px] -translate-x-1/2 -translate-y-1/2">
      <div className="h-[156px] w-[481px] rounded-xl bg-white px-6 pb-5 pt-4">
        <span className="font-sans text-xs font-medium tracking-wide text-zinc-500">
          SELECTED AGENT
        </span>
        {props.storeAgent ? (
          <div className="mt-4 flex h-20 rounded-lg bg-violet-50 p-3">
            {/* Left image */}
            <SmartImage
              src={props.storeAgent.agent_image[0]}
              alt="Agent cover"
              className="w-[350px] rounded-lg"
            />
            {/* Right content */}
            <div className="ml-3 flex flex-1 flex-col">
              <div className="mb-2 flex flex-col items-start">
                <span className="w-[292px] truncate font-sans text-[14px] font-medium leading-tight text-zinc-800">
                  {props.storeAgent.agent_name}
                </span>
                <span className="font-normal w-[292px] truncate font-sans text-xs text-zinc-600">
                  by {props.storeAgent.creator}
                </span>
              </div>
              <div className="flex w-[292px] items-center justify-between">
                <span className="truncate font-sans text-xs font-normal leading-tight text-zinc-600">
                  {props.storeAgent.runs.toLocaleString("en-US")} runs
                </span>
                <StarRating
                  className="font-sans text-xs font-normal leading-tight text-zinc-600"
                  starSize={12}
                  rating={props.storeAgent.rating || 0}
                />
              </div>
            </div>
          </div>
        ) : (
          <div className="mt-4 flex h-20 animate-pulse rounded-lg bg-gray-300 p-2" />
        )}
      </div>
    </div>
  );
}
@@ -1,9 +1,9 @@
import type { GraphMeta } from "@/lib/autogpt-server-api";
import type {
  BlockIOCredentialsSubSchema,
  CredentialsMetaInput,
} from "@/lib/autogpt-server-api/types";
import type { InputValues } from "./types";
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";

export function computeInitialAgentInputs(
  agent: GraphMeta | null,
@@ -21,7 +21,6 @@ export function computeInitialAgentInputs(
      result[key] = existingInputs[key];
      return;
    }
    // GraphIOSubSchema.default is typed as string, but server may return other primitives
    const def = (subSchema as unknown as { default?: string | number }).default;
    result[key] = def ?? "";
  });
@@ -29,40 +28,20 @@
  return result;
}

export function getAgentCredentialsInputFields(agent: GraphMeta | null) {
  const hasNoInputs =
    !agent?.credentials_input_schema ||
    typeof agent.credentials_input_schema !== "object" ||
    !("properties" in agent.credentials_input_schema) ||
    !agent.credentials_input_schema.properties;

  if (hasNoInputs) return {};

  return agent.credentials_input_schema.properties;
}

export function areAllCredentialsSet(
  fields: Record<string, BlockIOCredentialsSubSchema>,
  inputs: Record<string, CredentialsMetaInput | undefined>,
) {
  const required = Object.keys(fields || {});
  return required.every((k) => Boolean(inputs[k]));
}

type IsRunDisabledParams = {
  agent: GraphMeta | null;
  isRunning: boolean;
  agentInputs: InputValues | null | undefined;
  credentialsRequired: boolean;
  credentialsSatisfied: boolean;
  credentialsValid: boolean;
  credentialsLoaded: boolean;
};

export function isRunDisabled({
  agent,
  isRunning,
  agentInputs,
  credentialsRequired,
  credentialsSatisfied,
  credentialsValid,
  credentialsLoaded,
}: IsRunDisabledParams) {
  const hasEmptyInput = Object.values(agentInputs || {}).some(
    (value) => String(value).trim() === "",
@@ -71,7 +50,8 @@ export function isRunDisabled({
  if (hasEmptyInput) return true;
  if (!agent) return true;
  if (isRunning) return true;
  if (credentialsRequired && !credentialsSatisfied) return true;
  if (!credentialsValid) return true;
  if (!credentialsLoaded) return true;

  return false;
}
@@ -81,13 +61,3 @@ export function getSchemaDefaultCredentials(
): CredentialsMetaInput | undefined {
  return schema.default as CredentialsMetaInput | undefined;
}

export function sanitizeCredentials(
  map: Record<string, CredentialsMetaInput | undefined>,
): Record<string, CredentialsMetaInput> {
  const sanitized: Record<string, CredentialsMetaInput> = {};
  for (const [key, value] of Object.entries(map)) {
    if (value) sanitized[key] = value;
  }
  return sanitized;
}

@@ -1,224 +1,66 @@
"use client";
import SmartImage from "@/components/__legacy__/SmartImage";
import { useOnboarding } from "../../../../providers/onboarding/onboarding-provider";

import OnboardingButton from "../components/OnboardingButton";
import { OnboardingHeader, OnboardingStep } from "../components/OnboardingStep";
import { OnboardingText } from "../components/OnboardingText";
import StarRating from "../components/StarRating";
import {
  Card,
  CardContent,
  CardHeader,
  CardTitle,
} from "@/components/__legacy__/ui/card";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { GraphMeta, StoreAgentDetails } from "@/lib/autogpt-server-api";
import type { InputValues } from "./types";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { cn } from "@/lib/utils";
import { Play } from "lucide-react";
import { useRouter } from "next/navigation";
import { useEffect, useState } from "react";
import { RunAgentInputs } from "@/app/(platform)/library/agents/[id]/components/AgentRunsView/components/RunAgentInputs/RunAgentInputs";
import { InformationTooltip } from "@/components/molecules/InformationTooltip/InformationTooltip";
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/AgentRunsView/components/CredentialsInputs/CredentialsInputs";
import type { CredentialsMetaInput } from "@/lib/autogpt-server-api/types";
import {
  areAllCredentialsSet,
  computeInitialAgentInputs,
  getAgentCredentialsInputFields,
  isRunDisabled,
  getSchemaDefaultCredentials,
  sanitizeCredentials,
} from "./helpers";
import { isRunDisabled } from "./helpers";
import { useOnboardingRunStep } from "./useOnboardingRunStep";
import { RunAgentHint } from "./components/RunAgentHint";
import { SelectedAgentCard } from "./components/SelectedAgentCard";
import { AgentOnboardingCredentials } from "./components/AgentOnboardingCredentials/AgentOnboardingCredentials";
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { Skeleton } from "@/components/__legacy__/ui/skeleton";

export default function Page() {
  const { state, updateState, setStep } = useOnboarding(
    undefined,
    "AGENT_CHOICE",
  );
  const [showInput, setShowInput] = useState(false);
  const [agent, setAgent] = useState<GraphMeta | null>(null);
  const [storeAgent, setStoreAgent] = useState<StoreAgentDetails | null>(null);
  const [runningAgent, setRunningAgent] = useState(false);
  const [inputCredentials, setInputCredentials] = useState<
    Record<string, CredentialsMetaInput | undefined>
  >({});
  const { toast } = useToast();
  const router = useRouter();
  const api = useBackendAPI();
  const {
    ready,
    error,
    showInput,
    agent,
    onboarding,
    storeAgent,
    runningAgent,
    credentialsValid,
    credentialsLoaded,
    handleSetAgentInput,
    handleRunAgent,
    handleNewRun,
    handleCredentialsChange,
    handleCredentialsValidationChange,
    handleCredentialsLoadingChange,
  } = useOnboardingRunStep();

  useEffect(() => {
    setStep(5);
  }, []);

  useEffect(() => {
    if (!state?.selectedStoreListingVersionId) {
      return;
    }
    api
      .getStoreAgentByVersionId(state?.selectedStoreListingVersionId)
      .then((storeAgent) => {
        setStoreAgent(storeAgent);
      });
    api
      .getGraphMetaByStoreListingVersionID(state.selectedStoreListingVersionId)
      .then((meta) => {
        setAgent(meta);
        const update = computeInitialAgentInputs(
          meta,
          (state.agentInput as unknown as InputValues) || null,
        );
        updateState({ agentInput: update });
      });
  }, [api, setAgent, updateState, state?.selectedStoreListingVersionId]);

  const agentCredentialsInputFields = getAgentCredentialsInputFields(agent);

  const credentialsRequired =
    Object.keys(agentCredentialsInputFields || {}).length > 0;

  const allCredentialsAreSet = areAllCredentialsSet(
    agentCredentialsInputFields,
    inputCredentials,
  );

  function setAgentInput(key: string, value: string) {
    updateState({
      agentInput: {
        ...state?.agentInput,
        [key]: value,
      },
    });
  if (error) {
    return <ErrorCard responseError={error} />;
  }

  async function runAgent() {
    if (!agent) {
      return;
    }
    setRunningAgent(true);
    try {
      const libraryAgent = await api.addMarketplaceAgentToLibrary(
        storeAgent?.store_listing_version_id || "",
      );
      const { id: runID } = await api.executeGraph(
        libraryAgent.graph_id,
        libraryAgent.graph_version,
        state?.agentInput || {},
        sanitizeCredentials(inputCredentials),
      );
      updateState({
        onboardingAgentExecutionId: runID,
        agentRuns: (state?.agentRuns || 0) + 1,
      });
      router.push("/onboarding/6-congrats");
    } catch (error) {
      console.error("Error running agent:", error);
      toast({
        title: "Error running agent",
        description:
          "There was an error running your agent. Please try again or try choosing a different agent if it still fails.",
        variant: "destructive",
      });
      setRunningAgent(false);
    }
  }

  const runYourAgent = (
    <div className="ml-[104px] w-[481px] pl-5">
      <div className="flex flex-col">
        <OnboardingText variant="header">Run your first agent</OnboardingText>
        <span className="mt-9 text-base font-normal leading-normal text-zinc-600">
          A 'run' is when your agent starts working on a task
        </span>
        <span className="mt-4 text-base font-normal leading-normal text-zinc-600">
          Click on <b>New Run</b> below to try it out
        </span>

        <div
          onClick={() => {
            setShowInput(true);
            setStep(6);
            updateState({
              completedSteps: [
                ...(state?.completedSteps || []),
                "AGENT_NEW_RUN",
              ],
            });
          }}
          className={cn(
            "mt-16 flex h-[68px] w-[330px] items-center justify-center rounded-xl border-2 border-violet-700 bg-neutral-50",
            "cursor-pointer transition-all duration-200 ease-in-out hover:bg-violet-50",
          )}
        >
          <svg
            width="38"
            height="38"
            viewBox="0 0 32 32"
            xmlns="http://www.w3.org/2000/svg"
          >
            <g stroke="#6d28d9" strokeWidth="1.2" strokeLinecap="round">
              <line x1="16" y1="8" x2="16" y2="24" />
              <line x1="8" y1="16" x2="24" y2="16" />
            </g>
          </svg>
          <span className="ml-3 font-sans text-[19px] font-medium leading-normal text-violet-700">
            New run
          </span>
        </div>
  if (!ready) {
    return (
      <div className="flex flex-col gap-4">
        <Skeleton className="h-10 w-full" />
        <Skeleton className="h-10 w-full" />
      </div>
      </div>
    );
    );
  }

  return (
    <OnboardingStep dotted>
      <OnboardingHeader backHref={"/onboarding/4-agent"} transparent />
      {/* Agent card */}
      <div className="fixed left-1/4 top-1/2 w-[481px] -translate-x-1/2 -translate-y-1/2">
        <div className="h-[156px] w-[481px] rounded-xl bg-white px-6 pb-5 pt-4">
          <span className="font-sans text-xs font-medium tracking-wide text-zinc-500">
            SELECTED AGENT
          </span>
          {storeAgent ? (
            <div className="mt-4 flex h-20 rounded-lg bg-violet-50 p-2">
              {/* Left image */}
              <SmartImage
                src={storeAgent?.agent_image[0]}
                alt="Agent cover"
                imageContain
                className="w-[350px] rounded-lg"
              />
              {/* Right content */}
              <div className="ml-2 flex flex-1 flex-col">
                <span className="w-[292px] truncate font-sans text-[14px] font-medium leading-normal text-zinc-800">
                  {storeAgent?.agent_name}
                </span>
                <span className="mt-[5px] w-[292px] truncate font-sans text-xs font-normal leading-tight text-zinc-600">
                  by {storeAgent?.creator}
                </span>
                <div className="mt-auto flex w-[292px] justify-between">
                  <span className="mt-1 truncate font-sans text-xs font-normal leading-tight text-zinc-600">
                    {storeAgent?.runs.toLocaleString("en-US")} runs
                  </span>
                  <StarRating
                    className="font-sans text-xs font-normal leading-tight text-zinc-600"
                    starSize={12}
                    rating={storeAgent?.rating || 0}
                  />
                </div>
              </div>
            </div>
          ) : (
            <div className="mt-4 flex h-20 animate-pulse rounded-lg bg-gray-300 p-2" />
          )}
        </div>
      </div>
      <div className="flex min-h-[80vh] items-center justify-center">
        {/* Left side */}
        <SelectedAgentCard storeAgent={storeAgent} />
        <div className="w-[481px]" />
        {/* Right side */}
        {!showInput ? (
          runYourAgent
          <RunAgentHint handleNewRun={handleNewRun} />
        ) : (
          <div className="ml-[104px] w-[481px] pl-5">
            <div className="flex flex-col">
@@ -232,30 +74,7 @@ export default function Page() {
              <span className="mt-4 text-base font-normal leading-normal text-zinc-600">
                When you're done, click <b>Run Agent</b>.
              </span>
              {Object.entries(agentCredentialsInputFields || {}).map(
                ([key, inputSubSchema]) => (
                  <div key={key} className="mt-4">
                    <CredentialsInput
                      schema={inputSubSchema}
                      selectedCredentials={
                        inputCredentials[key] ??
                        getSchemaDefaultCredentials(inputSubSchema)
                      }
                      onSelectCredentials={(value) =>
                        setInputCredentials((prev) => ({
                          ...prev,
                          [key]: value,
                        }))
                      }
                      siblingInputs={
                        (state?.agentInput || undefined) as
                          | Record<string, any>
                          | undefined
                      }
                    />
                  </div>
                ),
              )}

              <Card className="agpt-box mt-4">
                <CardHeader>
                  <CardTitle className="font-poppins text-lg">Input</CardTitle>
@@ -272,13 +91,23 @@ export default function Page() {
                      </label>
                      <RunAgentInputs
                        schema={inputSubSchema}
                        value={state?.agentInput?.[key]}
                        value={onboarding.state?.agentInput?.[key]}
                        placeholder={inputSubSchema.description}
                        onChange={(value) => setAgentInput(key, value)}
                        onChange={(value) => handleSetAgentInput(key, value)}
                      />
                    </div>
                  ),
                )}
                  <AgentOnboardingCredentials
                    agent={agent}
                    siblingInputs={
                      (onboarding.state?.agentInput as Record<string, any>) ||
                      undefined
                    }
                    onCredentialsChange={handleCredentialsChange}
                    onValidationChange={handleCredentialsValidationChange}
                    onLoadingChange={handleCredentialsLoadingChange}
                  />
                </CardContent>
              </Card>
              <OnboardingButton
@@ -289,11 +118,12 @@ export default function Page() {
                  agent,
                  isRunning: runningAgent,
                  agentInputs:
                    (state?.agentInput as unknown as InputValues) || null,
                  credentialsRequired,
                  credentialsSatisfied: allCredentialsAreSet,
                    (onboarding.state?.agentInput as unknown as InputValues) ||
                    null,
                  credentialsValid,
                  credentialsLoaded,
                })}
                onClick={runAgent}
                onClick={handleRunAgent}
                icon={<Play className="mr-2" size={18} />}
              >
                Run agent

@@ -0,0 +1,162 @@
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
import { StoreAgentDetails } from "@/app/api/__generated__/models/storeAgentDetails";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
import { useRouter } from "next/navigation";
import { useEffect, useState } from "react";
import { computeInitialAgentInputs } from "./helpers";
import { InputValues } from "./types";
import {
  useGetV2GetAgentByVersion,
  useGetV2GetAgentGraph,
} from "@/app/api/__generated__/endpoints/store/store";

export function useOnboardingRunStep() {
  const onboarding = useOnboarding(undefined, "AGENT_CHOICE");

  const [showInput, setShowInput] = useState(false);
  const [agent, setAgent] = useState<GraphMeta | null>(null);
  const [storeAgent, setStoreAgent] = useState<StoreAgentDetails | null>(null);
  const [runningAgent, setRunningAgent] = useState(false);

  const [inputCredentials, setInputCredentials] = useState<
    Record<string, CredentialsMetaInput>
  >({});

  const [credentialsValid, setCredentialsValid] = useState(true);
  const [credentialsLoaded, setCredentialsLoaded] = useState(false);

  const { toast } = useToast();
  const router = useRouter();
  const api = useBackendAPI();

  const currentAgentVersion =
    onboarding.state?.selectedStoreListingVersionId ?? "";

  const storeAgentQuery = useGetV2GetAgentByVersion(currentAgentVersion, {
    query: { enabled: !!currentAgentVersion },
  });

  const graphMetaQuery = useGetV2GetAgentGraph(currentAgentVersion, {
    query: { enabled: !!currentAgentVersion },
  });

  useEffect(() => {
    onboarding.setStep(5);
  }, []);

  useEffect(() => {
    if (storeAgentQuery.data && storeAgentQuery.data.status === 200) {
      setStoreAgent(storeAgentQuery.data.data);
    }
  }, [storeAgentQuery.data]);

  useEffect(() => {
    if (
      graphMetaQuery.data &&
      graphMetaQuery.data.status === 200 &&
      onboarding.state
    ) {
      const graphMeta = graphMetaQuery.data.data as GraphMeta;

      setAgent(graphMeta);

      const update = computeInitialAgentInputs(
        graphMeta,
        (onboarding.state.agentInput as unknown as InputValues) || null,
      );

      onboarding.updateState({ agentInput: update });
    }
  }, [graphMetaQuery.data]);

  function handleNewRun() {
    if (!onboarding.state) return;

    setShowInput(true);
    onboarding.setStep(6);
    onboarding.updateState({
      completedSteps: [
        ...(onboarding.state.completedSteps || []),
        "AGENT_NEW_RUN",
      ],
    });
  }

  function handleSetAgentInput(key: string, value: string) {
    if (!onboarding.state) return;

    onboarding.updateState({
      agentInput: {
        ...onboarding.state.agentInput,
        [key]: value,
      },
    });
  }

  async function handleRunAgent() {
    if (!agent || !storeAgent || !onboarding.state) {
      toast({
        title: "Error getting agent",
        description:
          "Either the agent is not available or there was an error getting it.",
        variant: "destructive",
      });

      return;
    }

    setRunningAgent(true);

    try {
      const libraryAgent = await api.addMarketplaceAgentToLibrary(
        storeAgent?.store_listing_version_id || "",
      );

      const { id: runID } = await api.executeGraph(
        libraryAgent.graph_id,
        libraryAgent.graph_version,
        onboarding.state.agentInput || {},
        inputCredentials,
      );

      onboarding.updateState({
        onboardingAgentExecutionId: runID,
        agentRuns: (onboarding.state.agentRuns || 0) + 1,
      });

      router.push("/onboarding/6-congrats");
    } catch (error) {
      console.error("Error running agent:", error);

      toast({
        title: "Error running agent",
        description:
          "There was an error running your agent. Please try again or try choosing a different agent if it still fails.",
        variant: "destructive",
      });

      setRunningAgent(false);
    }
  }

  return {
    ready: graphMetaQuery.isSuccess && storeAgentQuery.isSuccess,
    error: graphMetaQuery.error || storeAgentQuery.error,
    agent,
    onboarding,
    showInput,
    storeAgent,
    runningAgent,
    credentialsValid,
    credentialsLoaded,
    handleSetAgentInput,
    handleRunAgent,
    handleNewRun,
    handleCredentialsChange: setInputCredentials,
    handleCredentialsValidationChange: setCredentialsValid,
    handleCredentialsLoadingChange: (v: boolean) => setCredentialsLoaded(!v),
  };
}
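Judging only from the return shape above, a compact, hypothetical consumer of this hook could look like the sketch below; the real onboarding page in this diff renders the full layout (SelectedAgentCard, RunAgentHint, inputs and credentials) instead.

```tsx
"use client";

import { useOnboardingRunStep } from "./useOnboardingRunStep";

// Hypothetical minimal consumer, for illustration only.
export function RunStepSummary() {
  const { ready, error, storeAgent, runningAgent, handleRunAgent } =
    useOnboardingRunStep();

  if (error) return <p>Something went wrong.</p>;
  if (!ready) return <p>Loading…</p>;

  return (
    <button disabled={runningAgent} onClick={handleRunAgent}>
      Run {storeAgent?.agent_name ?? "agent"}
    </button>
  );
}
```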
@@ -46,7 +46,7 @@ export default function StarRating({
      )}
    >
      {/* Display numerical rating */}
      <span className="mr-1 mt-1">{roundedRating}</span>
      <span className="mr-1 mt-0.5">{roundedRating}</span>

      {/* Display stars */}
      {stars.map((starType, index) => {

@@ -19,6 +19,7 @@ import WalletRefill from "./components/WalletRefill";
import { OnboardingStep } from "@/lib/autogpt-server-api";
import { storage, Key as StorageKey } from "@/services/storage/local-storage";
import { WalletIcon } from "@phosphor-icons/react";
import { useGetFlag, Flag } from "@/services/feature-flags/use-get-flag";

export interface Task {
  id: OnboardingStep;
@@ -40,6 +41,7 @@ export interface TaskGroup {

export default function Wallet() {
  const { state, updateState } = useOnboarding();
  const isPaymentEnabled = useGetFlag(Flag.ENABLE_PLATFORM_PAYMENT);

  const groups = useMemo<TaskGroup[]>(() => {
    return [
@@ -379,9 +381,7 @@ export default function Wallet() {
      </div>
      <ScrollArea className="max-h-[85vh] overflow-y-auto">
        {/* Top ups */}
        {process.env.NEXT_PUBLIC_SHOW_BILLING_PAGE === "true" && (
          <WalletRefill />
        )}
        {isPaymentEnabled && <WalletRefill />}
        {/* Tasks */}
        <p className="mx-1 my-3 font-sans text-xs font-normal text-zinc-400">
          Complete the following tasks to earn more credits!

@@ -1,16 +1,23 @@
"use client";

import { isServerSide } from "@/lib/utils/is-server-side";
import { useEffect, useState } from "react";
import { Text } from "@/components/atoms/Text/Text";
import { Card } from "@/components/atoms/Card/Card";
import { WaitlistErrorContent } from "@/components/auth/WaitlistErrorContent";
import { isWaitlistError } from "@/app/api/auth/utils";
import { useRouter } from "next/navigation";
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { environment } from "@/services/environment";

export default function AuthErrorPage() {
  const [errorType, setErrorType] = useState<string | null>(null);
  const [errorCode, setErrorCode] = useState<string | null>(null);
  const [errorDescription, setErrorDescription] = useState<string | null>(null);
  const router = useRouter();

  useEffect(() => {
    // This code only runs on the client side
    if (!isServerSide()) {
    if (!environment.isServerSide()) {
      const hash = window.location.hash.substring(1); // Remove the leading '#'
      const params = new URLSearchParams(hash);

@@ -23,15 +30,45 @@ export default function AuthErrorPage() {
  }, []);

  if (!errorType && !errorCode && !errorDescription) {
    return <div>Loading...</div>;
    return (
      <div className="flex h-screen items-center justify-center">
        <Text variant="body">Loading...</Text>
      </div>
    );
  }

  // Check if this is a waitlist/not allowed error
  const isWaitlistErr = isWaitlistError(errorCode, errorDescription);

  if (isWaitlistErr) {
    return (
      <div className="flex h-screen items-center justify-center">
        <Card className="w-full max-w-md p-8">
          <WaitlistErrorContent onBackToLogin={() => router.push("/login")} />
        </Card>
      </div>
    );
  }

  // Use ErrorCard for consistent error display
  const errorMessage = errorDescription
    ? `${errorDescription}. If this error persists, please contact support at contact@agpt.co`
    : "An authentication error occurred. Please contact support at contact@agpt.co";

  return (
    <div>
      <h1>Authentication Error</h1>
      {errorType && <p>Error Type: {errorType}</p>}
      {errorCode && <p>Error Code: {errorCode}</p>}
      {errorDescription && <p>Error Description: {errorDescription}</p>}
    <div className="flex h-screen items-center justify-center p-4">
      <div className="w-full max-w-md">
        <ErrorCard
          responseError={{
            message: errorMessage,
            detail: errorCode
              ? `Error code: ${errorCode}${errorType ? ` (${errorType})` : ""}`
              : undefined,
          }}
          context="authentication"
          onRetry={() => router.push("/login")}
        />
      </div>
    </div>
  );
}

@@ -0,0 +1,11 @@
import { RunGraph } from "./components/RunGraph";

export const BuilderActions = () => {
  return (
    <div className="absolute bottom-4 left-[50%] z-[100] -translate-x-1/2">
      {/* TODO: Add Agent Output */}
      <RunGraph />
      {/* TODO: Add Schedule run button */}
    </div>
  );
};
@@ -0,0 +1,32 @@
import { Button } from "@/components/atoms/Button/Button";
import { PlayIcon } from "lucide-react";
import { useRunGraph } from "./useRunGraph";
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
import { useShallow } from "zustand/react/shallow";
import { StopIcon } from "@phosphor-icons/react";
import { cn } from "@/lib/utils";

export const RunGraph = () => {
  const { runGraph, isSaving } = useRunGraph();
  const isGraphRunning = useGraphStore(
    useShallow((state) => state.isGraphRunning),
  );

  return (
    <Button
      variant="primary"
      size="large"
      className={cn(
        "relative min-w-44 border-none bg-gradient-to-r from-purple-500 to-pink-500 text-lg",
      )}
      onClick={() => runGraph()}
    >
      {!isGraphRunning && !isSaving ? (
        <PlayIcon className="mr-1 size-5" />
      ) : (
        <StopIcon className="mr-1 size-5" />
      )}
      {isGraphRunning || isSaving ? "Stop Agent" : "Run Agent"}
    </Button>
  );
};
@@ -0,0 +1,62 @@
import { usePostV1ExecuteGraphAgent } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { useNewSaveControl } from "../../../NewControlPanel/NewSaveControl/useNewSaveControl";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { parseAsInteger, parseAsString, useQueryStates } from "nuqs";
import { GraphExecutionMeta } from "@/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/use-agent-runs";
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
import { useShallow } from "zustand/react/shallow";

export const useRunGraph = () => {
  const { onSubmit: onSaveGraph, isLoading: isSaving } = useNewSaveControl({
    showToast: false,
  });
  const { toast } = useToast();
  const setIsGraphRunning = useGraphStore(
    useShallow((state) => state.setIsGraphRunning),
  );
  const [{ flowID, flowVersion }, setQueryStates] = useQueryStates({
    flowID: parseAsString,
    flowVersion: parseAsInteger,
    flowExecutionID: parseAsString,
  });

  const { mutateAsync: executeGraph } = usePostV1ExecuteGraphAgent({
    mutation: {
      onSuccess: (response) => {
        const { id } = response.data as GraphExecutionMeta;
        setQueryStates({
          flowExecutionID: id,
        });
      },
      onError: (error) => {
        setIsGraphRunning(false);

        toast({
          title: (error.detail as string) ?? "An unexpected error occurred.",
          description: "An unexpected error occurred.",
          variant: "destructive",
        });
      },
    },
  });

  const runGraph = async () => {
    setIsGraphRunning(true);
    await onSaveGraph(undefined);

    // Todo : We need to save graph which has inputs and credentials inputs
    await executeGraph({
      graphId: flowID ?? "",
      graphVersion: flowVersion || null,
      data: {
        inputs: {},
        credentials_inputs: {},
      },
    });
  };

  return {
    runGraph,
    isSaving,
  };
};