mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Compare commits
41 Commits
copilot/fi
...
seer/fix-s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ec1ebb4d1d | ||
|
|
18bbd8e572 | ||
|
|
047f011520 | ||
|
|
d11917eb10 | ||
|
|
4663066e65 | ||
|
|
48a0faa611 | ||
|
|
70d00b4104 | ||
|
|
aad0434cb2 | ||
|
|
f33ec1f2ec | ||
|
|
e68b873bcf | ||
|
|
4530e97e59 | ||
|
|
477c261488 | ||
|
|
8ac2228e1e | ||
|
|
91dd9364bb | ||
|
|
f314fbf14f | ||
|
|
a97ff641c3 | ||
|
|
114f604d7b | ||
|
|
3abea1ed96 | ||
|
|
da6e1ad26d | ||
|
|
634fffb967 | ||
|
|
f3ec426c82 | ||
|
|
0b267f573e | ||
|
|
7bd571d9ce | ||
|
|
7a331651ba | ||
|
|
5bc69adc33 | ||
|
|
f4bcc8494f | ||
|
|
4c000086e6 | ||
|
|
9c6cc5b29d | ||
|
|
b34973ca47 | ||
|
|
2bc6a56877 | ||
|
|
87c773d03a | ||
|
|
ebeefc96e8 | ||
|
|
83fe8d5b94 | ||
|
|
50689218ed | ||
|
|
ddff09a8e4 | ||
|
|
0c363a1cea | ||
|
|
e5d870a348 | ||
|
|
3f19cba28f | ||
|
|
a978e91271 | ||
|
|
f283e6c514 | ||
|
|
9fc2101e7e |
5
.github/workflows/platform-backend-ci.yml
vendored
5
.github/workflows/platform-backend-ci.yml
vendored
@@ -37,9 +37,7 @@ jobs:
|
||||
|
||||
services:
|
||||
redis:
|
||||
image: bitnami/redis:6.2
|
||||
env:
|
||||
REDIS_PASSWORD: testpassword
|
||||
image: redis:latest
|
||||
ports:
|
||||
- 6379:6379
|
||||
rabbitmq:
|
||||
@@ -204,7 +202,6 @@ jobs:
|
||||
JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
|
||||
REDIS_HOST: "localhost"
|
||||
REDIS_PORT: "6379"
|
||||
REDIS_PASSWORD: "testpassword"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
|
||||
env:
|
||||
|
||||
113
.github/workflows/platform-container-publish.yml
vendored
113
.github/workflows/platform-container-publish.yml
vendored
@@ -1,113 +0,0 @@
|
||||
name: Platform - Container Publishing
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
no_cache:
|
||||
type: boolean
|
||||
description: 'Build from scratch, without using cached layers'
|
||||
default: false
|
||||
registry:
|
||||
type: choice
|
||||
description: 'Container registry to publish to'
|
||||
options:
|
||||
- 'both'
|
||||
- 'ghcr'
|
||||
- 'dockerhub'
|
||||
default: 'both'
|
||||
|
||||
env:
|
||||
GHCR_REGISTRY: ghcr.io
|
||||
GHCR_IMAGE_BASE: ${{ github.repository_owner }}/autogpt-platform
|
||||
DOCKERHUB_IMAGE_BASE: ${{ secrets.DOCKER_USER }}/autogpt-platform
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
build-and-publish:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
component: [backend, frontend]
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
if: inputs.registry == 'both' || inputs.registry == 'ghcr' || github.event_name == 'release'
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.GHCR_REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
if: (inputs.registry == 'both' || inputs.registry == 'dockerhub' || github.event_name == 'release') && secrets.DOCKER_USER
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_PASSWORD }}
|
||||
|
||||
- name: Extract metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
${{ env.GHCR_REGISTRY }}/${{ env.GHCR_IMAGE_BASE }}-${{ matrix.component }}
|
||||
${{ secrets.DOCKER_USER && format('{0}-{1}', env.DOCKERHUB_IMAGE_BASE, matrix.component) || '' }}
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=pr
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=semver,pattern={{major}}
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
|
||||
- name: Set build context and dockerfile for backend
|
||||
if: matrix.component == 'backend'
|
||||
run: |
|
||||
echo "BUILD_CONTEXT=." >> $GITHUB_ENV
|
||||
echo "DOCKERFILE=autogpt_platform/backend/Dockerfile" >> $GITHUB_ENV
|
||||
echo "BUILD_TARGET=server" >> $GITHUB_ENV
|
||||
|
||||
- name: Set build context and dockerfile for frontend
|
||||
if: matrix.component == 'frontend'
|
||||
run: |
|
||||
echo "BUILD_CONTEXT=." >> $GITHUB_ENV
|
||||
echo "DOCKERFILE=autogpt_platform/frontend/Dockerfile" >> $GITHUB_ENV
|
||||
echo "BUILD_TARGET=prod" >> $GITHUB_ENV
|
||||
|
||||
- name: Build and push container image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: ${{ env.BUILD_CONTEXT }}
|
||||
file: ${{ env.DOCKERFILE }}
|
||||
target: ${{ env.BUILD_TARGET }}
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: ${{ !inputs.no_cache && 'type=gha' || '' }},scope=platform-${{ matrix.component }}
|
||||
cache-to: type=gha,scope=platform-${{ matrix.component }},mode=max
|
||||
|
||||
- name: Generate build summary
|
||||
run: |
|
||||
echo "## 🐳 Container Build Summary" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "**Component:** ${{ matrix.component }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "**Registry:** ${{ inputs.registry || 'both' }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "**Tags:** ${{ steps.meta.outputs.tags }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "### Images Published:" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
echo "${{ steps.meta.outputs.tags }}" | sed 's/,/\n/g' >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
@@ -1,6 +1,3 @@
|
||||
[pr_reviewer]
|
||||
num_code_suggestions=0
|
||||
|
||||
[pr_code_suggestions]
|
||||
commitable_code_suggestions=false
|
||||
num_code_suggestions=0
|
||||
|
||||
@@ -1,389 +0,0 @@
|
||||
# AutoGPT Platform Container Publishing
|
||||
|
||||
This document describes the container publishing infrastructure and deployment options for the AutoGPT Platform.
|
||||
|
||||
## Published Container Images
|
||||
|
||||
### GitHub Container Registry (GHCR) - Recommended
|
||||
|
||||
- **Backend**: `ghcr.io/significant-gravitas/autogpt-platform-backend`
|
||||
- **Frontend**: `ghcr.io/significant-gravitas/autogpt-platform-frontend`
|
||||
|
||||
### Docker Hub
|
||||
|
||||
- **Backend**: `significantgravitas/autogpt-platform-backend`
|
||||
- **Frontend**: `significantgravitas/autogpt-platform-frontend`
|
||||
|
||||
## Available Tags
|
||||
|
||||
- `latest` - Latest stable release from master branch
|
||||
- `v1.0.0`, `v1.1.0`, etc. - Specific version releases
|
||||
- `main` - Latest development build (use with caution)
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Using Docker Compose (Recommended)
|
||||
|
||||
```bash
|
||||
# Clone the repository (or just download the compose file)
|
||||
git clone https://github.com/Significant-Gravitas/AutoGPT.git
|
||||
cd AutoGPT/autogpt_platform
|
||||
|
||||
# Deploy with published images
|
||||
./deploy.sh deploy
|
||||
```
|
||||
|
||||
### Manual Docker Run
|
||||
|
||||
```bash
|
||||
# Start dependencies first
|
||||
docker network create autogpt
|
||||
|
||||
# PostgreSQL
|
||||
docker run -d --name postgres --network autogpt \
|
||||
-e POSTGRES_DB=autogpt \
|
||||
-e POSTGRES_USER=autogpt \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-v postgres_data:/var/lib/postgresql/data \
|
||||
postgres:15
|
||||
|
||||
# Redis
|
||||
docker run -d --name redis --network autogpt \
|
||||
-v redis_data:/data \
|
||||
redis:7-alpine redis-server --requirepass password
|
||||
|
||||
# RabbitMQ
|
||||
docker run -d --name rabbitmq --network autogpt \
|
||||
-e RABBITMQ_DEFAULT_USER=autogpt \
|
||||
-e RABBITMQ_DEFAULT_PASS=password \
|
||||
-p 15672:15672 \
|
||||
rabbitmq:3-management
|
||||
|
||||
# Backend
|
||||
docker run -d --name backend --network autogpt \
|
||||
-p 8000:8000 \
|
||||
-e DATABASE_URL=postgresql://autogpt:password@postgres:5432/autogpt \
|
||||
-e REDIS_HOST=redis \
|
||||
-e RABBITMQ_HOST=rabbitmq \
|
||||
ghcr.io/significant-gravitas/autogpt-platform-backend:latest
|
||||
|
||||
# Frontend
|
||||
docker run -d --name frontend --network autogpt \
|
||||
-p 3000:3000 \
|
||||
-e AGPT_SERVER_URL=http://localhost:8000/api \
|
||||
ghcr.io/significant-gravitas/autogpt-platform-frontend:latest
|
||||
```
|
||||
|
||||
## Deployment Scripts
|
||||
|
||||
### Deploy Script
|
||||
|
||||
The included `deploy.sh` script provides a complete deployment solution:
|
||||
|
||||
```bash
|
||||
# Basic deployment
|
||||
./deploy.sh deploy
|
||||
|
||||
# Deploy specific version
|
||||
./deploy.sh -v v1.0.0 deploy
|
||||
|
||||
# Deploy from Docker Hub
|
||||
./deploy.sh -r docker.io deploy
|
||||
|
||||
# Production deployment
|
||||
./deploy.sh -p production deploy
|
||||
|
||||
# Other operations
|
||||
./deploy.sh start # Start services
|
||||
./deploy.sh stop # Stop services
|
||||
./deploy.sh restart # Restart services
|
||||
./deploy.sh update # Update to latest
|
||||
./deploy.sh backup # Create backup
|
||||
./deploy.sh status # Show status
|
||||
./deploy.sh logs # Show logs
|
||||
./deploy.sh cleanup # Remove everything
|
||||
```
|
||||
|
||||
## Platform-Specific Deployment Guides
|
||||
|
||||
### Unraid
|
||||
|
||||
See [Unraid Deployment Guide](../docs/content/platform/deployment/unraid.md)
|
||||
|
||||
Key features:
|
||||
- Community Applications template
|
||||
- Web UI management
|
||||
- Automatic updates
|
||||
- Built-in backup system
|
||||
|
||||
### Home Assistant Add-on
|
||||
|
||||
See [Home Assistant Add-on Guide](../docs/content/platform/deployment/home-assistant.md)
|
||||
|
||||
Key features:
|
||||
- Native Home Assistant integration
|
||||
- Automation services
|
||||
- Entity monitoring
|
||||
- Backup integration
|
||||
|
||||
### Kubernetes
|
||||
|
||||
See [Kubernetes Deployment Guide](../docs/content/platform/deployment/kubernetes.md)
|
||||
|
||||
Key features:
|
||||
- Helm charts
|
||||
- Horizontal scaling
|
||||
- Health checks
|
||||
- Persistent volumes
|
||||
|
||||
## Container Architecture
|
||||
|
||||
### Backend Container
|
||||
|
||||
- **Base Image**: `debian:13-slim`
|
||||
- **Runtime**: Python 3.13 with Poetry
|
||||
- **Services**: REST API, WebSocket, Executor, Scheduler, Database Manager, Notification
|
||||
- **Ports**: 8000-8007 (depending on service)
|
||||
- **Health Check**: `GET /health`
|
||||
|
||||
### Frontend Container
|
||||
|
||||
- **Base Image**: `node:21-alpine`
|
||||
- **Runtime**: Next.js production build
|
||||
- **Port**: 3000
|
||||
- **Health Check**: HTTP 200 on root path
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
### Required Environment Variables
|
||||
|
||||
#### Backend
|
||||
```env
|
||||
DATABASE_URL=postgresql://user:pass@host:5432/db
|
||||
REDIS_HOST=redis
|
||||
RABBITMQ_HOST=rabbitmq
|
||||
JWT_SECRET=your-secret-key
|
||||
```
|
||||
|
||||
#### Frontend
|
||||
```env
|
||||
AGPT_SERVER_URL=http://backend:8000/api
|
||||
SUPABASE_URL=http://auth:8000
|
||||
```
|
||||
|
||||
### Optional Configuration
|
||||
|
||||
```env
|
||||
# Logging
|
||||
LOG_LEVEL=INFO
|
||||
ENABLE_DEBUG=false
|
||||
|
||||
# Performance
|
||||
REDIS_PASSWORD=your-redis-password
|
||||
RABBITMQ_PASSWORD=your-rabbitmq-password
|
||||
|
||||
# Security
|
||||
CORS_ORIGINS=http://localhost:3000
|
||||
```
|
||||
|
||||
## CI/CD Pipeline
|
||||
|
||||
### GitHub Actions Workflow
|
||||
|
||||
The publishing workflow (`.github/workflows/platform-container-publish.yml`) automatically:
|
||||
|
||||
1. **Triggers** on releases and manual dispatch
|
||||
2. **Builds** both backend and frontend containers
|
||||
3. **Tests** container functionality
|
||||
4. **Publishes** to both GHCR and Docker Hub
|
||||
5. **Tags** with version and latest
|
||||
|
||||
### Manual Publishing
|
||||
|
||||
```bash
|
||||
# Build and tag locally
|
||||
docker build -t ghcr.io/significant-gravitas/autogpt-platform-backend:latest \
|
||||
-f autogpt_platform/backend/Dockerfile \
|
||||
--target server .
|
||||
|
||||
docker build -t ghcr.io/significant-gravitas/autogpt-platform-frontend:latest \
|
||||
-f autogpt_platform/frontend/Dockerfile \
|
||||
--target prod .
|
||||
|
||||
# Push to registry
|
||||
docker push ghcr.io/significant-gravitas/autogpt-platform-backend:latest
|
||||
docker push ghcr.io/significant-gravitas/autogpt-platform-frontend:latest
|
||||
```
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### Container Security
|
||||
|
||||
1. **Non-root users** - Containers run as non-root
|
||||
2. **Minimal base images** - Using slim/alpine images
|
||||
3. **No secrets in images** - All secrets via environment variables
|
||||
4. **Read-only filesystem** - Where possible
|
||||
5. **Resource limits** - CPU and memory limits set
|
||||
|
||||
### Deployment Security
|
||||
|
||||
1. **Network isolation** - Use dedicated networks
|
||||
2. **TLS encryption** - Enable HTTPS in production
|
||||
3. **Secret management** - Use Docker secrets or external secret stores
|
||||
4. **Regular updates** - Keep images updated
|
||||
5. **Vulnerability scanning** - Regular security scans
|
||||
|
||||
## Monitoring
|
||||
|
||||
### Health Checks
|
||||
|
||||
All containers include health checks:
|
||||
|
||||
```bash
|
||||
# Check container health
|
||||
docker inspect --format='{{.State.Health.Status}}' container_name
|
||||
|
||||
# Manual health check
|
||||
curl http://localhost:8000/health
|
||||
```
|
||||
|
||||
### Metrics
|
||||
|
||||
The backend exposes Prometheus metrics at `/metrics`:
|
||||
|
||||
```bash
|
||||
curl http://localhost:8000/metrics
|
||||
```
|
||||
|
||||
### Logging
|
||||
|
||||
Containers log to stdout/stderr for easy aggregation:
|
||||
|
||||
```bash
|
||||
# View logs
|
||||
docker logs container_name
|
||||
|
||||
# Follow logs
|
||||
docker logs -f container_name
|
||||
|
||||
# Aggregate logs
|
||||
docker compose logs -f
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Container won't start**
|
||||
```bash
|
||||
# Check logs
|
||||
docker logs container_name
|
||||
|
||||
# Check environment
|
||||
docker exec container_name env
|
||||
```
|
||||
|
||||
2. **Database connection failed**
|
||||
```bash
|
||||
# Test connectivity
|
||||
docker exec backend ping postgres
|
||||
|
||||
# Check database status
|
||||
docker exec postgres pg_isready
|
||||
```
|
||||
|
||||
3. **Port conflicts**
|
||||
```bash
|
||||
# Check port usage
|
||||
ss -tuln | grep :3000
|
||||
|
||||
# Use different ports
|
||||
docker run -p 3001:3000 ...
|
||||
```
|
||||
|
||||
### Debug Mode
|
||||
|
||||
Enable debug mode for detailed logging:
|
||||
|
||||
```env
|
||||
LOG_LEVEL=DEBUG
|
||||
ENABLE_DEBUG=true
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### Resource Limits
|
||||
|
||||
```yaml
|
||||
# Docker Compose
|
||||
services:
|
||||
backend:
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
memory: 2G
|
||||
cpus: '1.0'
|
||||
reservations:
|
||||
memory: 1G
|
||||
cpus: '0.5'
|
||||
```
|
||||
|
||||
### Scaling
|
||||
|
||||
```bash
|
||||
# Scale backend services
|
||||
docker compose up -d --scale backend=3
|
||||
|
||||
# Or use Docker Swarm
|
||||
docker service scale backend=3
|
||||
```
|
||||
|
||||
## Backup and Recovery
|
||||
|
||||
### Data Backup
|
||||
|
||||
```bash
|
||||
# Database backup
|
||||
docker exec postgres pg_dump -U autogpt autogpt > backup.sql
|
||||
|
||||
# Volume backup
|
||||
docker run --rm -v postgres_data:/data -v $(pwd):/backup \
|
||||
alpine tar czf /backup/postgres_backup.tar.gz /data
|
||||
```
|
||||
|
||||
### Restore
|
||||
|
||||
```bash
|
||||
# Database restore
|
||||
docker exec -i postgres psql -U autogpt autogpt < backup.sql
|
||||
|
||||
# Volume restore
|
||||
docker run --rm -v postgres_data:/data -v $(pwd):/backup \
|
||||
alpine tar xzf /backup/postgres_backup.tar.gz -C /
|
||||
```
|
||||
|
||||
## Support
|
||||
|
||||
- **Documentation**: [Platform Docs](../docs/content/platform/)
|
||||
- **Issues**: [GitHub Issues](https://github.com/Significant-Gravitas/AutoGPT/issues)
|
||||
- **Discord**: [AutoGPT Community](https://discord.gg/autogpt)
|
||||
- **Docker Hub**: [Container Registry](https://hub.docker.com/r/significantgravitas/)
|
||||
|
||||
## Contributing
|
||||
|
||||
To contribute to the container infrastructure:
|
||||
|
||||
1. **Test locally** with `docker build` and `docker run`
|
||||
2. **Update documentation** if making changes
|
||||
3. **Test deployment scripts** on your platform
|
||||
4. **Submit PR** with clear description of changes
|
||||
|
||||
## Roadmap
|
||||
|
||||
- [ ] ARM64 support for Apple Silicon
|
||||
- [ ] Helm charts for Kubernetes
|
||||
- [ ] Official Unraid template
|
||||
- [ ] Home Assistant Add-on store submission
|
||||
- [ ] Multi-stage builds optimization
|
||||
- [ ] Security scanning integration
|
||||
- [ ] Performance benchmarking
|
||||
@@ -2,38 +2,16 @@
|
||||
|
||||
Welcome to the AutoGPT Platform - a powerful system for creating and running AI agents to solve business problems. This platform enables you to harness the power of artificial intelligence to automate tasks, analyze data, and generate insights for your organization.
|
||||
|
||||
## Deployment Options
|
||||
|
||||
### Quick Deploy with Published Containers (Recommended)
|
||||
|
||||
The fastest way to get started is using our pre-built containers:
|
||||
|
||||
```bash
|
||||
# Download and run with published images
|
||||
curl -fsSL https://raw.githubusercontent.com/Significant-Gravitas/AutoGPT/master/autogpt_platform/deploy.sh -o deploy.sh
|
||||
chmod +x deploy.sh
|
||||
./deploy.sh deploy
|
||||
```
|
||||
|
||||
Access the platform at http://localhost:3000 after deployment completes.
|
||||
|
||||
### Platform-Specific Deployments
|
||||
|
||||
- **Unraid**: [Deployment Guide](../docs/content/platform/deployment/unraid.md)
|
||||
- **Home Assistant**: [Add-on Guide](../docs/content/platform/deployment/home-assistant.md)
|
||||
- **Kubernetes**: [K8s Deployment](../docs/content/platform/deployment/kubernetes.md)
|
||||
- **General Containers**: [Container Guide](../docs/content/platform/container-deployment.md)
|
||||
|
||||
## Development Setup
|
||||
## Getting Started
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Docker
|
||||
- Docker Compose V2 (comes with Docker Desktop, or can be installed separately)
|
||||
|
||||
### Running from Source
|
||||
### Running the System
|
||||
|
||||
To run the AutoGPT Platform from source for development:
|
||||
To run the AutoGPT Platform, follow these steps:
|
||||
|
||||
1. Clone this repository to your local machine and navigate to the `autogpt_platform` directory within the repository:
|
||||
|
||||
@@ -179,28 +157,3 @@ If you need to update the API client after making changes to the backend API:
|
||||
```
|
||||
|
||||
This will fetch the latest OpenAPI specification and regenerate the TypeScript client code.
|
||||
|
||||
## Container Deployment
|
||||
|
||||
For production deployments and specific platforms, see our container deployment guides:
|
||||
|
||||
- **[Container Deployment Overview](CONTAINERS.md)** - Complete guide to using published containers
|
||||
- **[Deployment Script](deploy.sh)** - Automated deployment and management tool
|
||||
- **[Published Images](docker-compose.published.yml)** - Docker Compose for published containers
|
||||
|
||||
### Published Container Images
|
||||
|
||||
- **Backend**: `ghcr.io/significant-gravitas/autogpt-platform-backend:latest`
|
||||
- **Frontend**: `ghcr.io/significant-gravitas/autogpt-platform-frontend:latest`
|
||||
|
||||
### Quick Production Deployment
|
||||
|
||||
```bash
|
||||
# Deploy with published containers
|
||||
./deploy.sh deploy
|
||||
|
||||
# Or use the published compose file directly
|
||||
docker compose -f docker-compose.published.yml up -d
|
||||
```
|
||||
|
||||
For detailed deployment instructions, troubleshooting, and platform-specific guides, see the [Container Documentation](CONTAINERS.md).
|
||||
|
||||
@@ -10,7 +10,7 @@ from .jwt_utils import get_jwt_payload, verify_user
|
||||
from .models import User
|
||||
|
||||
|
||||
def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
|
||||
async def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
|
||||
"""
|
||||
FastAPI dependency that requires a valid authenticated user.
|
||||
|
||||
@@ -20,7 +20,9 @@ def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User
|
||||
return verify_user(jwt_payload, admin_only=False)
|
||||
|
||||
|
||||
def requires_admin_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
|
||||
async def requires_admin_user(
|
||||
jwt_payload: dict = fastapi.Security(get_jwt_payload),
|
||||
) -> User:
|
||||
"""
|
||||
FastAPI dependency that requires a valid admin user.
|
||||
|
||||
@@ -30,7 +32,7 @@ def requires_admin_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -
|
||||
return verify_user(jwt_payload, admin_only=True)
|
||||
|
||||
|
||||
def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
|
||||
async def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
|
||||
"""
|
||||
FastAPI dependency that returns the ID of the authenticated user.
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ class TestAuthDependencies:
|
||||
"""Create a test client."""
|
||||
return TestClient(app)
|
||||
|
||||
def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
|
||||
async def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
|
||||
"""Test requires_user with valid JWT payload."""
|
||||
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
|
||||
|
||||
@@ -53,12 +53,12 @@ class TestAuthDependencies:
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = requires_user(jwt_payload)
|
||||
user = await requires_user(jwt_payload)
|
||||
assert isinstance(user, User)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.role == "user"
|
||||
|
||||
def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
|
||||
async def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
|
||||
"""Test requires_user accepts admin users."""
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
@@ -69,28 +69,28 @@ class TestAuthDependencies:
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = requires_user(jwt_payload)
|
||||
user = await requires_user(jwt_payload)
|
||||
assert user.user_id == "admin-456"
|
||||
assert user.role == "admin"
|
||||
|
||||
def test_requires_user_missing_sub(self):
|
||||
async def test_requires_user_missing_sub(self):
|
||||
"""Test requires_user with missing user ID."""
|
||||
jwt_payload = {"role": "user", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
requires_user(jwt_payload)
|
||||
await requires_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found" in exc_info.value.detail
|
||||
|
||||
def test_requires_user_empty_sub(self):
|
||||
async def test_requires_user_empty_sub(self):
|
||||
"""Test requires_user with empty user ID."""
|
||||
jwt_payload = {"sub": "", "role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
requires_user(jwt_payload)
|
||||
await requires_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
|
||||
async def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
|
||||
"""Test requires_admin_user with admin role."""
|
||||
jwt_payload = {
|
||||
"sub": "admin-789",
|
||||
@@ -101,51 +101,51 @@ class TestAuthDependencies:
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = requires_admin_user(jwt_payload)
|
||||
user = await requires_admin_user(jwt_payload)
|
||||
assert user.user_id == "admin-789"
|
||||
assert user.role == "admin"
|
||||
|
||||
def test_requires_admin_user_with_regular_user(self):
|
||||
async def test_requires_admin_user_with_regular_user(self):
|
||||
"""Test requires_admin_user rejects regular users."""
|
||||
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
requires_admin_user(jwt_payload)
|
||||
await requires_admin_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Admin access required" in exc_info.value.detail
|
||||
|
||||
def test_requires_admin_user_missing_role(self):
|
||||
async def test_requires_admin_user_missing_role(self):
|
||||
"""Test requires_admin_user with missing role."""
|
||||
jwt_payload = {"sub": "user-123", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
requires_admin_user(jwt_payload)
|
||||
await requires_admin_user(jwt_payload)
|
||||
|
||||
def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
|
||||
async def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
|
||||
"""Test get_user_id extracts user ID correctly."""
|
||||
jwt_payload = {"sub": "user-id-xyz", "role": "user"}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user_id = get_user_id(jwt_payload)
|
||||
user_id = await get_user_id(jwt_payload)
|
||||
assert user_id == "user-id-xyz"
|
||||
|
||||
def test_get_user_id_missing_sub(self):
|
||||
async def test_get_user_id_missing_sub(self):
|
||||
"""Test get_user_id with missing user ID."""
|
||||
jwt_payload = {"role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_user_id(jwt_payload)
|
||||
await get_user_id(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found" in exc_info.value.detail
|
||||
|
||||
def test_get_user_id_none_sub(self):
|
||||
async def test_get_user_id_none_sub(self):
|
||||
"""Test get_user_id with None user ID."""
|
||||
jwt_payload = {"sub": None, "role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
get_user_id(jwt_payload)
|
||||
await get_user_id(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
@@ -170,7 +170,7 @@ class TestAuthDependenciesIntegration:
|
||||
|
||||
return _create_token
|
||||
|
||||
def test_endpoint_auth_enabled_no_token(self):
|
||||
async def test_endpoint_auth_enabled_no_token(self):
|
||||
"""Test endpoints require token when auth is enabled."""
|
||||
app = FastAPI()
|
||||
|
||||
@@ -184,7 +184,7 @@ class TestAuthDependenciesIntegration:
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_endpoint_with_valid_token(self, create_token):
|
||||
async def test_endpoint_with_valid_token(self, create_token):
|
||||
"""Test endpoint with valid JWT token."""
|
||||
app = FastAPI()
|
||||
|
||||
@@ -203,7 +203,7 @@ class TestAuthDependenciesIntegration:
|
||||
assert response.status_code == 200
|
||||
assert response.json()["user_id"] == "test-user"
|
||||
|
||||
def test_admin_endpoint_requires_admin_role(self, create_token):
|
||||
async def test_admin_endpoint_requires_admin_role(self, create_token):
|
||||
"""Test admin endpoint rejects non-admin users."""
|
||||
app = FastAPI()
|
||||
|
||||
@@ -240,7 +240,7 @@ class TestAuthDependenciesIntegration:
|
||||
class TestAuthDependenciesEdgeCases:
|
||||
"""Edge case tests for authentication dependencies."""
|
||||
|
||||
def test_dependency_with_complex_payload(self):
|
||||
async def test_dependency_with_complex_payload(self):
|
||||
"""Test dependencies handle complex JWT payloads."""
|
||||
complex_payload = {
|
||||
"sub": "user-123",
|
||||
@@ -256,14 +256,14 @@ class TestAuthDependenciesEdgeCases:
|
||||
"exp": 9999999999,
|
||||
}
|
||||
|
||||
user = requires_user(complex_payload)
|
||||
user = await requires_user(complex_payload)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.email == "test@example.com"
|
||||
|
||||
admin = requires_admin_user(complex_payload)
|
||||
admin = await requires_admin_user(complex_payload)
|
||||
assert admin.role == "admin"
|
||||
|
||||
def test_dependency_with_unicode_in_payload(self):
|
||||
async def test_dependency_with_unicode_in_payload(self):
|
||||
"""Test dependencies handle unicode in JWT payloads."""
|
||||
unicode_payload = {
|
||||
"sub": "user-😀-123",
|
||||
@@ -272,11 +272,11 @@ class TestAuthDependenciesEdgeCases:
|
||||
"name": "日本語",
|
||||
}
|
||||
|
||||
user = requires_user(unicode_payload)
|
||||
user = await requires_user(unicode_payload)
|
||||
assert "😀" in user.user_id
|
||||
assert user.email == "测试@example.com"
|
||||
|
||||
def test_dependency_with_null_values(self):
|
||||
async def test_dependency_with_null_values(self):
|
||||
"""Test dependencies handle null values in payload."""
|
||||
null_payload = {
|
||||
"sub": "user-123",
|
||||
@@ -286,18 +286,18 @@ class TestAuthDependenciesEdgeCases:
|
||||
"metadata": None,
|
||||
}
|
||||
|
||||
user = requires_user(null_payload)
|
||||
user = await requires_user(null_payload)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.email is None
|
||||
|
||||
def test_concurrent_requests_isolation(self):
|
||||
async def test_concurrent_requests_isolation(self):
|
||||
"""Test that concurrent requests don't interfere with each other."""
|
||||
payload1 = {"sub": "user-1", "role": "user"}
|
||||
payload2 = {"sub": "user-2", "role": "admin"}
|
||||
|
||||
# Simulate concurrent processing
|
||||
user1 = requires_user(payload1)
|
||||
user2 = requires_admin_user(payload2)
|
||||
user1 = await requires_user(payload1)
|
||||
user2 = await requires_admin_user(payload2)
|
||||
|
||||
assert user1.user_id == "user-1"
|
||||
assert user2.user_id == "user-2"
|
||||
@@ -314,7 +314,7 @@ class TestAuthDependenciesEdgeCases:
|
||||
({"sub": "user", "role": "user"}, "Admin access required", True),
|
||||
],
|
||||
)
|
||||
def test_dependency_error_cases(
|
||||
async def test_dependency_error_cases(
|
||||
self, payload, expected_error: str, admin_only: bool
|
||||
):
|
||||
"""Test that errors propagate correctly through dependencies."""
|
||||
@@ -325,7 +325,7 @@ class TestAuthDependenciesEdgeCases:
|
||||
verify_user(payload, admin_only=admin_only)
|
||||
assert expected_error in exc_info.value.detail
|
||||
|
||||
def test_dependency_valid_user(self):
|
||||
async def test_dependency_valid_user(self):
|
||||
"""Test valid user case for dependency."""
|
||||
# Import verify_user to test it directly since dependencies use FastAPI Security
|
||||
from autogpt_libs.auth.jwt_utils import verify_user
|
||||
|
||||
@@ -16,7 +16,7 @@ bearer_jwt_auth = HTTPBearer(
|
||||
)
|
||||
|
||||
|
||||
def get_jwt_payload(
|
||||
async def get_jwt_payload(
|
||||
credentials: HTTPAuthorizationCredentials | None = Security(bearer_jwt_auth),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
|
||||
@@ -116,32 +116,32 @@ def test_parse_jwt_token_missing_audience():
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_get_jwt_payload_with_valid_token():
|
||||
async def test_get_jwt_payload_with_valid_token():
|
||||
"""Test extracting JWT payload with valid bearer token."""
|
||||
token = create_token(TEST_USER_PAYLOAD)
|
||||
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
|
||||
|
||||
result = jwt_utils.get_jwt_payload(credentials)
|
||||
result = await jwt_utils.get_jwt_payload(credentials)
|
||||
assert result["sub"] == "test-user-id"
|
||||
assert result["role"] == "user"
|
||||
|
||||
|
||||
def test_get_jwt_payload_no_credentials():
|
||||
async def test_get_jwt_payload_no_credentials():
|
||||
"""Test JWT payload when no credentials provided."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.get_jwt_payload(None)
|
||||
await jwt_utils.get_jwt_payload(None)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Authorization header is missing" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_get_jwt_payload_invalid_token():
|
||||
async def test_get_jwt_payload_invalid_token():
|
||||
"""Test JWT payload extraction with invalid token."""
|
||||
credentials = HTTPAuthorizationCredentials(
|
||||
scheme="Bearer", credentials="invalid.token.here"
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.get_jwt_payload(credentials)
|
||||
await jwt_utils.get_jwt_payload(credentials)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid token" in exc_info.value.detail
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
@@ -13,8 +15,8 @@ class RateLimitSettings(BaseSettings):
|
||||
default="6379", description="Redis port", validation_alias="REDIS_PORT"
|
||||
)
|
||||
|
||||
redis_password: str = Field(
|
||||
default="password",
|
||||
redis_password: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Redis password",
|
||||
validation_alias="REDIS_PASSWORD",
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ class RateLimiter:
|
||||
self,
|
||||
redis_host: str = RATE_LIMIT_SETTINGS.redis_host,
|
||||
redis_port: str = RATE_LIMIT_SETTINGS.redis_port,
|
||||
redis_password: str = RATE_LIMIT_SETTINGS.redis_password,
|
||||
redis_password: str | None = RATE_LIMIT_SETTINGS.redis_password,
|
||||
requests_per_minute: int = RATE_LIMIT_SETTINGS.requests_per_minute,
|
||||
):
|
||||
self.redis = Redis(
|
||||
|
||||
@@ -1,90 +1,68 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
Awaitable,
|
||||
Any,
|
||||
Callable,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
R_co = TypeVar("R_co", covariant=True)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
pass
|
||||
def _make_hashable_key(
|
||||
args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> tuple[Any, ...]:
|
||||
"""
|
||||
Convert args and kwargs into a hashable cache key.
|
||||
|
||||
Handles unhashable types like dict, list, set by converting them to
|
||||
their sorted string representations.
|
||||
"""
|
||||
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
|
||||
pass
|
||||
def make_hashable(obj: Any) -> Any:
|
||||
"""Recursively convert an object to a hashable representation."""
|
||||
if isinstance(obj, dict):
|
||||
# Sort dict items to ensure consistent ordering
|
||||
return (
|
||||
"__dict__",
|
||||
tuple(sorted((k, make_hashable(v)) for k, v in obj.items())),
|
||||
)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return ("__list__", tuple(make_hashable(item) for item in obj))
|
||||
elif isinstance(obj, set):
|
||||
return ("__set__", tuple(sorted(make_hashable(item) for item in obj)))
|
||||
elif hasattr(obj, "__dict__"):
|
||||
# Handle objects with __dict__ attribute
|
||||
return ("__obj__", obj.__class__.__name__, make_hashable(obj.__dict__))
|
||||
else:
|
||||
# For basic hashable types (str, int, bool, None, etc.)
|
||||
try:
|
||||
hash(obj)
|
||||
return obj
|
||||
except TypeError:
|
||||
# Fallback: convert to string representation
|
||||
return ("__str__", str(obj))
|
||||
|
||||
|
||||
def thread_cached(
|
||||
func: Callable[P, R] | Callable[P, Awaitable[R]],
|
||||
) -> Callable[P, R] | Callable[P, Awaitable[R]]:
|
||||
thread_local = threading.local()
|
||||
|
||||
def _clear():
|
||||
if hasattr(thread_local, "cache"):
|
||||
del thread_local.cache
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
if key not in cache:
|
||||
cache[key] = await cast(Callable[P, Awaitable[R]], func)(
|
||||
*args, **kwargs
|
||||
)
|
||||
return cache[key]
|
||||
|
||||
setattr(async_wrapper, "clear_cache", _clear)
|
||||
return async_wrapper
|
||||
|
||||
else:
|
||||
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
if key not in cache:
|
||||
cache[key] = func(*args, **kwargs)
|
||||
return cache[key]
|
||||
|
||||
setattr(sync_wrapper, "clear_cache", _clear)
|
||||
return sync_wrapper
|
||||
|
||||
|
||||
def clear_thread_cache(func: Callable) -> None:
|
||||
if clear := getattr(func, "clear_cache", None):
|
||||
clear()
|
||||
|
||||
|
||||
FuncT = TypeVar("FuncT")
|
||||
|
||||
|
||||
R_co = TypeVar("R_co", covariant=True)
|
||||
hashable_args = tuple(make_hashable(arg) for arg in args)
|
||||
hashable_kwargs = tuple(sorted((k, make_hashable(v)) for k, v in kwargs.items()))
|
||||
return (hashable_args, hashable_kwargs)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AsyncCachedFunction(Protocol[P, R_co]):
|
||||
"""Protocol for async functions with cache management methods."""
|
||||
class CachedFunction(Protocol[P, R_co]):
|
||||
"""Protocol for cached functions with cache management methods."""
|
||||
|
||||
def cache_clear(self) -> None:
|
||||
"""Clear all cached entries."""
|
||||
@@ -94,101 +72,180 @@ class AsyncCachedFunction(Protocol[P, R_co]):
|
||||
"""Get cache statistics."""
|
||||
return {}
|
||||
|
||||
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
|
||||
def cache_delete(self, *args: P.args, **kwargs: P.kwargs) -> bool:
|
||||
"""Delete a specific cache entry by its arguments. Returns True if entry existed."""
|
||||
return False
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
|
||||
"""Call the cached function."""
|
||||
return None # type: ignore
|
||||
|
||||
|
||||
def async_ttl_cache(
|
||||
maxsize: int = 128, ttl_seconds: int | None = None
|
||||
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
|
||||
def cached(
|
||||
*,
|
||||
maxsize: int = 128,
|
||||
ttl_seconds: int | None = None,
|
||||
) -> Callable[[Callable], CachedFunction]:
|
||||
"""
|
||||
TTL (Time To Live) cache decorator for async functions.
|
||||
Thundering herd safe cache decorator for both sync and async functions.
|
||||
|
||||
Similar to functools.lru_cache but works with async functions and includes optional TTL.
|
||||
Uses double-checked locking to prevent multiple threads/coroutines from
|
||||
executing the expensive operation simultaneously during cache misses.
|
||||
|
||||
Args:
|
||||
func: The function to cache (when used without parentheses)
|
||||
maxsize: Maximum number of cached entries
|
||||
ttl_seconds: Time to live in seconds. If None, entries never expire (like lru_cache)
|
||||
ttl_seconds: Time to live in seconds. If None, entries never expire
|
||||
|
||||
Returns:
|
||||
Decorator function
|
||||
Decorated function or decorator
|
||||
|
||||
Example:
|
||||
# With TTL
|
||||
@async_ttl_cache(maxsize=1000, ttl_seconds=300)
|
||||
async def api_call(param: str) -> dict:
|
||||
@cache() # Default: maxsize=128, no TTL
|
||||
def expensive_sync_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
# Without TTL (permanent cache like lru_cache)
|
||||
@async_ttl_cache(maxsize=1000)
|
||||
async def expensive_computation(param: str) -> dict:
|
||||
@cache() # Works with async too
|
||||
async def expensive_async_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
@cache(maxsize=1000, ttl_seconds=300) # Custom maxsize and TTL
|
||||
def another_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
async_func: Callable[P, Awaitable[R]],
|
||||
) -> AsyncCachedFunction[P, R]:
|
||||
# Cache storage - use union type to handle both cases
|
||||
cache_storage: dict[tuple, R | Tuple[R, float]] = {}
|
||||
def decorator(target_func):
|
||||
# Cache storage and per-event-loop locks
|
||||
cache_storage = {}
|
||||
_event_loop_locks = {} # Maps event loop to its asyncio.Lock
|
||||
|
||||
@wraps(async_func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# Create cache key from arguments
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
current_time = time.time()
|
||||
if inspect.iscoroutinefunction(target_func):
|
||||
|
||||
# Check if we have a valid cached entry
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
# No TTL - return cached result directly
|
||||
logger.debug(
|
||||
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
return cast(R, cache_storage[key])
|
||||
else:
|
||||
# With TTL - check expiration
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
logger.debug(
|
||||
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
return cast(R, result)
|
||||
def _get_cache_lock():
|
||||
"""Get or create an asyncio.Lock for the current event loop."""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# No event loop, use None as default key
|
||||
loop = None
|
||||
|
||||
if loop not in _event_loop_locks:
|
||||
return _event_loop_locks.setdefault(loop, asyncio.Lock())
|
||||
return _event_loop_locks[loop]
|
||||
|
||||
@wraps(target_func)
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
current_time = time.time()
|
||||
|
||||
# Fast path: check cache without lock
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
logger.debug(f"Cache hit for {target_func.__name__}")
|
||||
return cache_storage[key]
|
||||
else:
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
logger.debug(f"Cache hit for {target_func.__name__}")
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
async with _get_cache_lock():
|
||||
# Double-check: another coroutine might have populated cache
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
return cache_storage[key]
|
||||
else:
|
||||
# Expired entry
|
||||
del cache_storage[key]
|
||||
logger.debug(
|
||||
f"Cache entry expired for {async_func.__name__}"
|
||||
)
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
return result
|
||||
|
||||
# Cache miss or expired - fetch fresh data
|
||||
logger.debug(
|
||||
f"Cache miss for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
result = await async_func(*args, **kwargs)
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {target_func.__name__}")
|
||||
result = await target_func(*args, **kwargs)
|
||||
|
||||
# Store in cache
|
||||
if ttl_seconds is None:
|
||||
cache_storage[key] = result
|
||||
else:
|
||||
cache_storage[key] = (result, current_time)
|
||||
# Store result
|
||||
if ttl_seconds is None:
|
||||
cache_storage[key] = result
|
||||
else:
|
||||
cache_storage[key] = (result, current_time)
|
||||
|
||||
# Simple cleanup when cache gets too large
|
||||
if len(cache_storage) > maxsize:
|
||||
# Remove oldest entries (simple FIFO cleanup)
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
logger.debug(
|
||||
f"Cache cleanup: removed {len(oldest_keys)} entries for {async_func.__name__}"
|
||||
)
|
||||
# Cleanup if needed
|
||||
if len(cache_storage) > maxsize:
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = (
|
||||
list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
|
||||
)
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
|
||||
return result
|
||||
return result
|
||||
|
||||
# Add cache management methods (similar to functools.lru_cache)
|
||||
wrapper = async_wrapper
|
||||
|
||||
else:
|
||||
# Sync function with threading.Lock
|
||||
cache_lock = threading.Lock()
|
||||
|
||||
@wraps(target_func)
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
current_time = time.time()
|
||||
|
||||
# Fast path: check cache without lock
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
logger.debug(f"Cache hit for {target_func.__name__}")
|
||||
return cache_storage[key]
|
||||
else:
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
logger.debug(f"Cache hit for {target_func.__name__}")
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
with cache_lock:
|
||||
# Double-check: another thread might have populated cache
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
return cache_storage[key]
|
||||
else:
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
return result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {target_func.__name__}")
|
||||
result = target_func(*args, **kwargs)
|
||||
|
||||
# Store result
|
||||
if ttl_seconds is None:
|
||||
cache_storage[key] = result
|
||||
else:
|
||||
cache_storage[key] = (result, current_time)
|
||||
|
||||
# Cleanup if needed
|
||||
if len(cache_storage) > maxsize:
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = (
|
||||
list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
|
||||
)
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
|
||||
return result
|
||||
|
||||
wrapper = sync_wrapper
|
||||
|
||||
# Add cache management methods
|
||||
def cache_clear() -> None:
|
||||
cache_storage.clear()
|
||||
|
||||
@@ -199,68 +256,84 @@ def async_ttl_cache(
|
||||
"ttl_seconds": ttl_seconds,
|
||||
}
|
||||
|
||||
# Attach methods to wrapper
|
||||
def cache_delete(*args, **kwargs) -> bool:
|
||||
"""Delete a specific cache entry. Returns True if entry existed."""
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
if key in cache_storage:
|
||||
del cache_storage[key]
|
||||
return True
|
||||
return False
|
||||
|
||||
setattr(wrapper, "cache_clear", cache_clear)
|
||||
setattr(wrapper, "cache_info", cache_info)
|
||||
setattr(wrapper, "cache_delete", cache_delete)
|
||||
|
||||
return cast(AsyncCachedFunction[P, R], wrapper)
|
||||
return cast(CachedFunction, wrapper)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@overload
|
||||
def async_cache(
|
||||
func: Callable[P, Awaitable[R]],
|
||||
) -> AsyncCachedFunction[P, R]:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def async_cache(
|
||||
func: None = None,
|
||||
*,
|
||||
maxsize: int = 128,
|
||||
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
|
||||
pass
|
||||
|
||||
|
||||
def async_cache(
|
||||
func: Callable[P, Awaitable[R]] | None = None,
|
||||
*,
|
||||
maxsize: int = 128,
|
||||
) -> (
|
||||
AsyncCachedFunction[P, R]
|
||||
| Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]
|
||||
):
|
||||
def thread_cached(func):
|
||||
"""
|
||||
Process-level cache decorator for async functions (no TTL).
|
||||
Thread-local cache decorator for both sync and async functions.
|
||||
|
||||
Similar to functools.lru_cache but works with async functions.
|
||||
This is a convenience wrapper around async_ttl_cache with ttl_seconds=None.
|
||||
Each thread gets its own cache, which is useful for request-scoped caching
|
||||
in web applications where you want to cache within a single request but
|
||||
not across requests.
|
||||
|
||||
Args:
|
||||
func: The async function to cache (when used without parentheses)
|
||||
maxsize: Maximum number of cached entries
|
||||
func: The function to cache
|
||||
|
||||
Returns:
|
||||
Decorated function or decorator
|
||||
Decorated function with thread-local caching
|
||||
|
||||
Example:
|
||||
# Without parentheses (uses default maxsize=128)
|
||||
@async_cache
|
||||
async def get_data(param: str) -> dict:
|
||||
@thread_cached
|
||||
def expensive_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
# With parentheses and custom maxsize
|
||||
@async_cache(maxsize=1000)
|
||||
async def expensive_computation(param: str) -> dict:
|
||||
# Expensive computation here
|
||||
@thread_cached # Works with async too
|
||||
async def expensive_async_operation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
"""
|
||||
if func is None:
|
||||
# Called with parentheses @async_cache() or @async_cache(maxsize=...)
|
||||
return async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
|
||||
thread_local = threading.local()
|
||||
|
||||
def _clear():
|
||||
if hasattr(thread_local, "cache"):
|
||||
del thread_local.cache
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
if key not in cache:
|
||||
cache[key] = await func(*args, **kwargs)
|
||||
return cache[key]
|
||||
|
||||
setattr(async_wrapper, "clear_cache", _clear)
|
||||
return async_wrapper
|
||||
|
||||
else:
|
||||
# Called without parentheses @async_cache
|
||||
decorator = async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
|
||||
return decorator(func)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
if key not in cache:
|
||||
cache[key] = func(*args, **kwargs)
|
||||
return cache[key]
|
||||
|
||||
setattr(sync_wrapper, "clear_cache", _clear)
|
||||
return sync_wrapper
|
||||
|
||||
|
||||
def clear_thread_cache(func: Callable) -> None:
|
||||
"""Clear thread-local cache for a function."""
|
||||
if clear := getattr(func, "clear_cache", None):
|
||||
clear()
|
||||
|
||||
@@ -16,12 +16,7 @@ from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from autogpt_libs.utils.cache import (
|
||||
async_cache,
|
||||
async_ttl_cache,
|
||||
clear_thread_cache,
|
||||
thread_cached,
|
||||
)
|
||||
from autogpt_libs.utils.cache import cached, clear_thread_cache, thread_cached
|
||||
|
||||
|
||||
class TestThreadCached:
|
||||
@@ -330,102 +325,202 @@ class TestThreadCached:
|
||||
assert mock.call_count == 2
|
||||
|
||||
|
||||
class TestAsyncTTLCache:
|
||||
"""Tests for the @async_ttl_cache decorator."""
|
||||
class TestCache:
|
||||
"""Tests for the unified @cache decorator (works for both sync and async)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_caching(self):
|
||||
"""Test basic caching functionality."""
|
||||
def test_basic_sync_caching(self):
|
||||
"""Test basic sync caching functionality."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def cached_function(x: int, y: int = 0) -> int:
|
||||
@cached()
|
||||
def expensive_sync_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x + y
|
||||
|
||||
# First call
|
||||
result1 = expensive_sync_function(1, 2)
|
||||
assert result1 == 3
|
||||
assert call_count == 1
|
||||
|
||||
# Second call with same args - should use cache
|
||||
result2 = expensive_sync_function(1, 2)
|
||||
assert result2 == 3
|
||||
assert call_count == 1
|
||||
|
||||
# Different args - should call function again
|
||||
result3 = expensive_sync_function(2, 3)
|
||||
assert result3 == 5
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_async_caching(self):
|
||||
"""Test basic async caching functionality."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
async def expensive_async_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01) # Simulate async work
|
||||
return x + y
|
||||
|
||||
# First call
|
||||
result1 = await cached_function(1, 2)
|
||||
result1 = await expensive_async_function(1, 2)
|
||||
assert result1 == 3
|
||||
assert call_count == 1
|
||||
|
||||
# Second call with same args - should use cache
|
||||
result2 = await cached_function(1, 2)
|
||||
result2 = await expensive_async_function(1, 2)
|
||||
assert result2 == 3
|
||||
assert call_count == 1 # No additional call
|
||||
assert call_count == 1
|
||||
|
||||
# Different args - should call function again
|
||||
result3 = await cached_function(2, 3)
|
||||
result3 = await expensive_async_function(2, 3)
|
||||
assert result3 == 5
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ttl_expiration(self):
|
||||
"""Test that cache entries expire after TTL."""
|
||||
def test_sync_thundering_herd_protection(self):
|
||||
"""Test that concurrent sync calls don't cause thundering herd."""
|
||||
call_count = 0
|
||||
results = []
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
|
||||
async def short_lived_cache(x: int) -> int:
|
||||
@cached()
|
||||
def slow_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 2
|
||||
time.sleep(0.1) # Simulate expensive operation
|
||||
return x * x
|
||||
|
||||
def worker():
|
||||
result = slow_function(5)
|
||||
results.append(result)
|
||||
|
||||
# Launch multiple concurrent threads
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
futures = [executor.submit(worker) for _ in range(5)]
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
# All results should be the same
|
||||
assert all(result == 25 for result in results)
|
||||
# Only one thread should have executed the expensive operation
|
||||
assert call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_thundering_herd_protection(self):
|
||||
"""Test that concurrent async calls don't cause thundering herd."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
async def slow_async_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.1) # Simulate expensive operation
|
||||
return x * x
|
||||
|
||||
# Launch concurrent coroutines
|
||||
tasks = [slow_async_function(7) for _ in range(5)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# All results should be the same
|
||||
assert all(result == 49 for result in results)
|
||||
# Only one coroutine should have executed the expensive operation
|
||||
assert call_count == 1
|
||||
|
||||
def test_ttl_functionality(self):
|
||||
"""Test TTL functionality with sync function."""
|
||||
call_count = 0
|
||||
|
||||
@cached(maxsize=10, ttl_seconds=1) # Short TTL
|
||||
def ttl_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 3
|
||||
|
||||
# First call
|
||||
result1 = await short_lived_cache(5)
|
||||
assert result1 == 10
|
||||
result1 = ttl_function(3)
|
||||
assert result1 == 9
|
||||
assert call_count == 1
|
||||
|
||||
# Second call immediately - should use cache
|
||||
result2 = await short_lived_cache(5)
|
||||
assert result2 == 10
|
||||
result2 = ttl_function(3)
|
||||
assert result2 == 9
|
||||
assert call_count == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
time.sleep(1.1)
|
||||
|
||||
# Third call after expiration - should call function again
|
||||
result3 = ttl_function(3)
|
||||
assert result3 == 9
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_ttl_functionality(self):
|
||||
"""Test TTL functionality with async function."""
|
||||
call_count = 0
|
||||
|
||||
@cached(maxsize=10, ttl_seconds=1) # Short TTL
|
||||
async def async_ttl_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
return x * 4
|
||||
|
||||
# First call
|
||||
result1 = await async_ttl_function(3)
|
||||
assert result1 == 12
|
||||
assert call_count == 1
|
||||
|
||||
# Second call immediately - should use cache
|
||||
result2 = await async_ttl_function(3)
|
||||
assert result2 == 12
|
||||
assert call_count == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# Third call after expiration - should call function again
|
||||
result3 = await short_lived_cache(5)
|
||||
assert result3 == 10
|
||||
result3 = await async_ttl_function(3)
|
||||
assert result3 == 12
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_info(self):
|
||||
def test_cache_info(self):
|
||||
"""Test cache info functionality."""
|
||||
|
||||
@async_ttl_cache(maxsize=5, ttl_seconds=300)
|
||||
async def info_test_function(x: int) -> int:
|
||||
@cached(maxsize=10, ttl_seconds=60)
|
||||
def info_test_function(x: int) -> int:
|
||||
return x * 3
|
||||
|
||||
# Check initial cache info
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 0
|
||||
assert info["maxsize"] == 5
|
||||
assert info["ttl_seconds"] == 300
|
||||
assert info["maxsize"] == 10
|
||||
assert info["ttl_seconds"] == 60
|
||||
|
||||
# Add an entry
|
||||
await info_test_function(1)
|
||||
info_test_function(1)
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_clear(self):
|
||||
def test_cache_clear(self):
|
||||
"""Test cache clearing functionality."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def clearable_function(x: int) -> int:
|
||||
@cached()
|
||||
def clearable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 4
|
||||
|
||||
# First call
|
||||
result1 = await clearable_function(2)
|
||||
result1 = clearable_function(2)
|
||||
assert result1 == 8
|
||||
assert call_count == 1
|
||||
|
||||
# Second call - should use cache
|
||||
result2 = await clearable_function(2)
|
||||
result2 = clearable_function(2)
|
||||
assert result2 == 8
|
||||
assert call_count == 1
|
||||
|
||||
@@ -433,273 +528,149 @@ class TestAsyncTTLCache:
|
||||
clearable_function.cache_clear()
|
||||
|
||||
# Third call after clear - should call function again
|
||||
result3 = await clearable_function(2)
|
||||
result3 = clearable_function(2)
|
||||
assert result3 == 8
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maxsize_cleanup(self):
|
||||
"""Test that cache cleans up when maxsize is exceeded."""
|
||||
async def test_async_cache_clear(self):
|
||||
"""Test cache clearing functionality with async function."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=3, ttl_seconds=60)
|
||||
async def size_limited_function(x: int) -> int:
|
||||
@cached()
|
||||
async def async_clearable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x**2
|
||||
await asyncio.sleep(0.01)
|
||||
return x * 5
|
||||
|
||||
# Fill cache to maxsize
|
||||
await size_limited_function(1) # call_count: 1
|
||||
await size_limited_function(2) # call_count: 2
|
||||
await size_limited_function(3) # call_count: 3
|
||||
|
||||
info = size_limited_function.cache_info()
|
||||
assert info["size"] == 3
|
||||
|
||||
# Add one more entry - should trigger cleanup
|
||||
await size_limited_function(4) # call_count: 4
|
||||
|
||||
# Cache size should be reduced (cleanup removes oldest entries)
|
||||
info = size_limited_function.cache_info()
|
||||
assert info["size"] is not None and info["size"] <= 3 # Should be cleaned up
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_argument_variations(self):
|
||||
"""Test caching with different argument patterns."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def arg_test_function(a: int, b: str = "default", *, c: int = 100) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"{a}-{b}-{c}"
|
||||
|
||||
# Different ways to call with same logical arguments
|
||||
result1 = await arg_test_function(1, "test", c=200)
|
||||
assert call_count == 1
|
||||
|
||||
# Same arguments, same order - should use cache
|
||||
result2 = await arg_test_function(1, "test", c=200)
|
||||
assert call_count == 1
|
||||
assert result1 == result2
|
||||
|
||||
# Different arguments - should call function
|
||||
result3 = await arg_test_function(2, "test", c=200)
|
||||
assert call_count == 2
|
||||
assert result1 != result3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_handling(self):
|
||||
"""Test that exceptions are not cached."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def exception_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if x < 0:
|
||||
raise ValueError("Negative value not allowed")
|
||||
return x * 2
|
||||
|
||||
# Successful call - should be cached
|
||||
result1 = await exception_function(5)
|
||||
# First call
|
||||
result1 = await async_clearable_function(2)
|
||||
assert result1 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Same successful call - should use cache
|
||||
result2 = await exception_function(5)
|
||||
# Second call - should use cache
|
||||
result2 = await async_clearable_function(2)
|
||||
assert result2 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Exception call - should not be cached
|
||||
with pytest.raises(ValueError):
|
||||
await exception_function(-1)
|
||||
# Clear cache
|
||||
async_clearable_function.cache_clear()
|
||||
|
||||
# Third call after clear - should call function again
|
||||
result3 = await async_clearable_function(2)
|
||||
assert result3 == 10
|
||||
assert call_count == 2
|
||||
|
||||
# Same exception call - should call again (not cached)
|
||||
with pytest.raises(ValueError):
|
||||
await exception_function(-1)
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_function_returns_results_not_coroutines(self):
|
||||
"""Test that cached async functions return actual results, not coroutines."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
async def async_result_function(x: int) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
return f"result_{x}"
|
||||
|
||||
# First call
|
||||
result1 = await async_result_function(1)
|
||||
assert result1 == "result_1"
|
||||
assert isinstance(result1, str) # Should be string, not coroutine
|
||||
assert call_count == 1
|
||||
|
||||
# Second call - should return cached result (string), not coroutine
|
||||
result2 = await async_result_function(1)
|
||||
assert result2 == "result_1"
|
||||
assert isinstance(result2, str) # Should be string, not coroutine
|
||||
assert call_count == 1 # Function should not be called again
|
||||
|
||||
# Verify results are identical
|
||||
assert result1 is result2 # Should be same cached object
|
||||
|
||||
def test_cache_delete(self):
|
||||
"""Test selective cache deletion functionality."""
|
||||
call_count = 0
|
||||
|
||||
@cached()
|
||||
def deletable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 6
|
||||
|
||||
# First call for x=1
|
||||
result1 = deletable_function(1)
|
||||
assert result1 == 6
|
||||
assert call_count == 1
|
||||
|
||||
# First call for x=2
|
||||
result2 = deletable_function(2)
|
||||
assert result2 == 12
|
||||
assert call_count == 2
|
||||
|
||||
# Second calls - should use cache
|
||||
assert deletable_function(1) == 6
|
||||
assert deletable_function(2) == 12
|
||||
assert call_count == 2
|
||||
|
||||
# Delete specific entry for x=1
|
||||
was_deleted = deletable_function.cache_delete(1)
|
||||
assert was_deleted is True
|
||||
|
||||
# Call with x=1 should execute function again
|
||||
result3 = deletable_function(1)
|
||||
assert result3 == 6
|
||||
assert call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_calls(self):
|
||||
"""Test caching behavior with concurrent calls."""
|
||||
call_count = 0
|
||||
# Call with x=2 should still use cache
|
||||
assert deletable_function(2) == 12
|
||||
assert call_count == 3
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def concurrent_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.05) # Simulate work
|
||||
return x * x
|
||||
|
||||
# Launch concurrent calls with same arguments
|
||||
tasks = [concurrent_function(3) for _ in range(5)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# All results should be the same
|
||||
assert all(result == 9 for result in results)
|
||||
|
||||
# Note: Due to race conditions, call_count might be up to 5 for concurrent calls
|
||||
# This tests that the cache doesn't break under concurrent access
|
||||
assert 1 <= call_count <= 5
|
||||
|
||||
|
||||
class TestAsyncCache:
|
||||
"""Tests for the @async_cache decorator (no TTL)."""
|
||||
# Try to delete non-existent entry
|
||||
was_deleted = deletable_function.cache_delete(99)
|
||||
assert was_deleted is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_caching_no_ttl(self):
|
||||
"""Test basic caching functionality without TTL."""
|
||||
async def test_async_cache_delete(self):
|
||||
"""Test selective cache deletion functionality with async function."""
|
||||
call_count = 0
|
||||
|
||||
@async_cache(maxsize=10)
|
||||
async def cached_function(x: int, y: int = 0) -> int:
|
||||
@cached()
|
||||
async def async_deletable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01) # Simulate async work
|
||||
return x + y
|
||||
await asyncio.sleep(0.01)
|
||||
return x * 7
|
||||
|
||||
# First call
|
||||
result1 = await cached_function(1, 2)
|
||||
assert result1 == 3
|
||||
# First call for x=1
|
||||
result1 = await async_deletable_function(1)
|
||||
assert result1 == 7
|
||||
assert call_count == 1
|
||||
|
||||
# Second call with same args - should use cache
|
||||
result2 = await cached_function(1, 2)
|
||||
assert result2 == 3
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Third call after some time - should still use cache (no TTL)
|
||||
await asyncio.sleep(0.05)
|
||||
result3 = await cached_function(1, 2)
|
||||
assert result3 == 3
|
||||
assert call_count == 1 # Still no additional call
|
||||
|
||||
# Different args - should call function again
|
||||
result4 = await cached_function(2, 3)
|
||||
assert result4 == 5
|
||||
# First call for x=2
|
||||
result2 = await async_deletable_function(2)
|
||||
assert result2 == 14
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_ttl_vs_ttl_behavior(self):
|
||||
"""Test the difference between TTL and no-TTL caching."""
|
||||
ttl_call_count = 0
|
||||
no_ttl_call_count = 0
|
||||
# Second calls - should use cache
|
||||
assert await async_deletable_function(1) == 7
|
||||
assert await async_deletable_function(2) == 14
|
||||
assert call_count == 2
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
|
||||
async def ttl_function(x: int) -> int:
|
||||
nonlocal ttl_call_count
|
||||
ttl_call_count += 1
|
||||
return x * 2
|
||||
# Delete specific entry for x=1
|
||||
was_deleted = async_deletable_function.cache_delete(1)
|
||||
assert was_deleted is True
|
||||
|
||||
@async_cache(maxsize=10) # No TTL
|
||||
async def no_ttl_function(x: int) -> int:
|
||||
nonlocal no_ttl_call_count
|
||||
no_ttl_call_count += 1
|
||||
return x * 2
|
||||
# Call with x=1 should execute function again
|
||||
result3 = await async_deletable_function(1)
|
||||
assert result3 == 7
|
||||
assert call_count == 3
|
||||
|
||||
# First calls
|
||||
await ttl_function(5)
|
||||
await no_ttl_function(5)
|
||||
assert ttl_call_count == 1
|
||||
assert no_ttl_call_count == 1
|
||||
# Call with x=2 should still use cache
|
||||
assert await async_deletable_function(2) == 14
|
||||
assert call_count == 3
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# Second calls after TTL expiry
|
||||
await ttl_function(5) # Should call function again (TTL expired)
|
||||
await no_ttl_function(5) # Should use cache (no TTL)
|
||||
assert ttl_call_count == 2 # TTL function called again
|
||||
assert no_ttl_call_count == 1 # No-TTL function still cached
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_cache_info(self):
|
||||
"""Test cache info for no-TTL cache."""
|
||||
|
||||
@async_cache(maxsize=5)
|
||||
async def info_test_function(x: int) -> int:
|
||||
return x * 3
|
||||
|
||||
# Check initial cache info
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 0
|
||||
assert info["maxsize"] == 5
|
||||
assert info["ttl_seconds"] is None # No TTL
|
||||
|
||||
# Add an entry
|
||||
await info_test_function(1)
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 1
|
||||
|
||||
|
||||
class TestTTLOptional:
|
||||
"""Tests for optional TTL functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ttl_none_behavior(self):
|
||||
"""Test that ttl_seconds=None works like no TTL."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=None)
|
||||
async def no_ttl_via_none(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x**2
|
||||
|
||||
# First call
|
||||
result1 = await no_ttl_via_none(3)
|
||||
assert result1 == 9
|
||||
assert call_count == 1
|
||||
|
||||
# Wait (would expire if there was TTL)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Second call - should still use cache
|
||||
result2 = await no_ttl_via_none(3)
|
||||
assert result2 == 9
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Check cache info
|
||||
info = no_ttl_via_none.cache_info()
|
||||
assert info["ttl_seconds"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_options_comparison(self):
|
||||
"""Test different cache options work as expected."""
|
||||
ttl_calls = 0
|
||||
no_ttl_calls = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # With TTL
|
||||
async def ttl_function(x: int) -> int:
|
||||
nonlocal ttl_calls
|
||||
ttl_calls += 1
|
||||
return x * 10
|
||||
|
||||
@async_cache(maxsize=10) # Process-level cache (no TTL)
|
||||
async def process_function(x: int) -> int:
|
||||
nonlocal no_ttl_calls
|
||||
no_ttl_calls += 1
|
||||
return x * 10
|
||||
|
||||
# Both should cache initially
|
||||
await ttl_function(3)
|
||||
await process_function(3)
|
||||
assert ttl_calls == 1
|
||||
assert no_ttl_calls == 1
|
||||
|
||||
# Immediate second calls - both should use cache
|
||||
await ttl_function(3)
|
||||
await process_function(3)
|
||||
assert ttl_calls == 1
|
||||
assert no_ttl_calls == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# After TTL expiry
|
||||
await ttl_function(3) # Should call function again
|
||||
await process_function(3) # Should still use cache
|
||||
assert ttl_calls == 2 # TTL cache expired, called again
|
||||
assert no_ttl_calls == 1 # Process cache never expires
|
||||
# Try to delete non-existent entry
|
||||
was_deleted = async_deletable_function.cache_delete(99)
|
||||
assert was_deleted is False
|
||||
|
||||
@@ -21,7 +21,7 @@ PRISMA_SCHEMA="postgres/schema.prisma"
|
||||
# Redis Configuration
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_PASSWORD=password
|
||||
# REDIS_PASSWORD=
|
||||
|
||||
# RabbitMQ Credentials
|
||||
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
|
||||
@@ -66,6 +66,11 @@ NVIDIA_API_KEY=
|
||||
GITHUB_CLIENT_ID=
|
||||
GITHUB_CLIENT_SECRET=
|
||||
|
||||
# Notion OAuth App server credentials - https://developers.notion.com/docs/authorization
|
||||
# Configure a public integration
|
||||
NOTION_CLIENT_ID=
|
||||
NOTION_CLIENT_SECRET=
|
||||
|
||||
# Google OAuth App server credentials - https://console.cloud.google.com/apis/credentials, and enable gmail api and set scopes
|
||||
# https://console.cloud.google.com/apis/credentials/consent ?project=<your_project_id>
|
||||
# You'll need to add/enable the following scopes (minimum):
|
||||
|
||||
10
autogpt_platform/backend/.gitignore
vendored
10
autogpt_platform/backend/.gitignore
vendored
@@ -9,4 +9,12 @@ secrets/*
|
||||
!secrets/.gitkeep
|
||||
|
||||
*.ignore.*
|
||||
*.ign.*
|
||||
*.ign.*
|
||||
|
||||
# Load test results and reports
|
||||
load-tests/*_RESULTS.md
|
||||
load-tests/*_REPORT.md
|
||||
load-tests/results/
|
||||
load-tests/*.json
|
||||
load-tests/*.log
|
||||
load-tests/node_modules/*
|
||||
|
||||
@@ -9,8 +9,15 @@ WORKDIR /app
|
||||
|
||||
RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy
|
||||
|
||||
# Update package list and install Python and build dependencies
|
||||
# Install Node.js repository key and setup
|
||||
RUN apt-get update --allow-releaseinfo-change --fix-missing \
|
||||
&& apt-get install -y curl ca-certificates gnupg \
|
||||
&& mkdir -p /etc/apt/keyrings \
|
||||
&& curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg \
|
||||
&& echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" | tee /etc/apt/sources.list.d/nodesource.list
|
||||
|
||||
# Update package list and install Python, Node.js, and build dependencies
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y \
|
||||
python3.13 \
|
||||
python3.13-dev \
|
||||
@@ -20,7 +27,9 @@ RUN apt-get update --allow-releaseinfo-change --fix-missing \
|
||||
libpq5 \
|
||||
libz-dev \
|
||||
libssl-dev \
|
||||
postgresql-client
|
||||
postgresql-client \
|
||||
nodejs \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
ENV POETRY_HOME=/opt/poetry
|
||||
ENV POETRY_NO_INTERACTION=1
|
||||
@@ -54,13 +63,18 @@ ENV PATH=/opt/poetry/bin:$PATH
|
||||
# Install Python without upgrading system-managed packages
|
||||
RUN apt-get update && apt-get install -y \
|
||||
python3.13 \
|
||||
python3-pip
|
||||
python3-pip \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy only necessary files from builder
|
||||
COPY --from=builder /app /app
|
||||
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
||||
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
|
||||
# Copy Prisma binaries
|
||||
# Copy Node.js installation for Prisma
|
||||
COPY --from=builder /usr/bin/node /usr/bin/node
|
||||
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
|
||||
COPY --from=builder /usr/bin/npm /usr/bin/npm
|
||||
COPY --from=builder /usr/bin/npx /usr/bin/npx
|
||||
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
|
||||
|
||||
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import functools
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
@@ -6,6 +5,8 @@ import re
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
from autogpt_libs.utils.cache import cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -15,7 +16,7 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@functools.cache
|
||||
@cached()
|
||||
def load_all_blocks() -> dict[str, type["Block"]]:
|
||||
from backend.data.block import Block
|
||||
from backend.util.settings import Config
|
||||
|
||||
214
autogpt_platform/backend/backend/blocks/ai_condition.py
Normal file
214
autogpt_platform/backend/backend/blocks/ai_condition.py
Normal file
@@ -0,0 +1,214 @@
|
||||
from typing import Any
|
||||
|
||||
from backend.blocks.llm import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
AIBlockBase,
|
||||
AICredentials,
|
||||
AICredentialsField,
|
||||
LlmModel,
|
||||
LLMResponse,
|
||||
llm_call,
|
||||
)
|
||||
from backend.data.block import BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
|
||||
|
||||
|
||||
class AIConditionBlock(AIBlockBase):
|
||||
"""
|
||||
An AI-powered condition block that uses natural language to evaluate conditions.
|
||||
|
||||
This block allows users to define conditions in plain English (e.g., "the input is an email address",
|
||||
"the input is a city in the USA") and uses AI to determine if the input satisfies the condition.
|
||||
It provides the same yes/no data pass-through functionality as the standard ConditionBlock.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
input_value: Any = SchemaField(
|
||||
description="The input value to evaluate with the AI condition",
|
||||
placeholder="Enter the value to be evaluated (text, number, or any data)",
|
||||
)
|
||||
condition: str = SchemaField(
|
||||
description="A plaintext English description of the condition to evaluate",
|
||||
placeholder="E.g., 'the input is the body of an email', 'the input is a City in the USA', 'the input is an error or a refusal'",
|
||||
)
|
||||
yes_value: Any = SchemaField(
|
||||
description="(Optional) Value to output if the condition is true. If not provided, input_value will be used.",
|
||||
placeholder="Leave empty to use input_value, or enter a specific value",
|
||||
default=None,
|
||||
)
|
||||
no_value: Any = SchemaField(
|
||||
description="(Optional) Value to output if the condition is false. If not provided, input_value will be used.",
|
||||
placeholder="Leave empty to use input_value, or enter a specific value",
|
||||
default=None,
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for evaluating the condition.",
|
||||
advanced=False,
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: bool = SchemaField(
|
||||
description="The result of the AI condition evaluation (True or False)"
|
||||
)
|
||||
yes_output: Any = SchemaField(
|
||||
description="The output value if the condition is true"
|
||||
)
|
||||
no_output: Any = SchemaField(
|
||||
description="The output value if the condition is false"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the AI evaluation is uncertain or fails"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="553ec5b8-6c45-4299-8d75-b394d05f72ff",
|
||||
input_schema=AIConditionBlock.Input,
|
||||
output_schema=AIConditionBlock.Output,
|
||||
description="Uses AI to evaluate natural language conditions and provide conditional outputs",
|
||||
categories={BlockCategory.AI, BlockCategory.LOGIC},
|
||||
test_input={
|
||||
"input_value": "john@example.com",
|
||||
"condition": "the input is an email address",
|
||||
"yes_value": "Valid email",
|
||||
"no_value": "Not an email",
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("result", True),
|
||||
("yes_output", "Valid email"),
|
||||
],
|
||||
test_mock={
|
||||
"llm_call": lambda *args, **kwargs: LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response="true",
|
||||
tool_calls=None,
|
||||
prompt_tokens=50,
|
||||
completion_tokens=10,
|
||||
reasoning=None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
async def llm_call(
|
||||
self,
|
||||
credentials: APIKeyCredentials,
|
||||
llm_model: LlmModel,
|
||||
prompt: list,
|
||||
max_tokens: int,
|
||||
) -> LLMResponse:
|
||||
"""Wrapper method for llm_call to enable mocking in tests."""
|
||||
return await llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=llm_model,
|
||||
prompt=prompt,
|
||||
force_json_output=False,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Evaluate the AI condition and return appropriate outputs.
|
||||
"""
|
||||
# Prepare the yes and no values, using input_value as default
|
||||
yes_value = (
|
||||
input_data.yes_value
|
||||
if input_data.yes_value is not None
|
||||
else input_data.input_value
|
||||
)
|
||||
no_value = (
|
||||
input_data.no_value
|
||||
if input_data.no_value is not None
|
||||
else input_data.input_value
|
||||
)
|
||||
|
||||
# Convert input_value to string for AI evaluation
|
||||
input_str = str(input_data.input_value)
|
||||
|
||||
# Create the prompt for AI evaluation
|
||||
prompt = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are an AI assistant that evaluates conditions based on input data. "
|
||||
"You must respond with only 'true' or 'false' (lowercase) to indicate whether "
|
||||
"the given condition is met by the input value. Be accurate and consider the "
|
||||
"context and meaning of both the input and the condition."
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"Input value: {input_str}\n"
|
||||
f"Condition to evaluate: {input_data.condition}\n\n"
|
||||
f"Does the input value satisfy the condition? Respond with only 'true' or 'false'."
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
# Call the LLM
|
||||
try:
|
||||
response = await self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
max_tokens=10, # We only expect a true/false response
|
||||
)
|
||||
|
||||
# Extract the boolean result from the response
|
||||
response_text = response.response.strip().lower()
|
||||
if response_text == "true":
|
||||
result = True
|
||||
elif response_text == "false":
|
||||
result = False
|
||||
else:
|
||||
# If the response is not clear, try to interpret it using word boundaries
|
||||
import re
|
||||
|
||||
# Use word boundaries to avoid false positives like 'untrue' or '10'
|
||||
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", response_text))
|
||||
|
||||
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
|
||||
result = True
|
||||
elif tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
|
||||
result = False
|
||||
else:
|
||||
# Unclear or conflicting response - default to False and yield error
|
||||
result = False
|
||||
yield "error", f"Unclear AI response: '{response.response}'"
|
||||
|
||||
# Update internal stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
)
|
||||
)
|
||||
self.prompt = response.prompt
|
||||
|
||||
except Exception as e:
|
||||
# In case of any error, default to False to be safe
|
||||
result = False
|
||||
# Log the error but don't fail the block execution
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"AI condition evaluation failed: {str(e)}")
|
||||
yield "error", f"AI evaluation failed: {str(e)}"
|
||||
|
||||
# Yield results
|
||||
yield "result", result
|
||||
|
||||
if result:
|
||||
yield "yes_output", yes_value
|
||||
else:
|
||||
yield "no_output", no_value
|
||||
@@ -1,8 +1,9 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
from typing import Literal, Optional
|
||||
|
||||
from e2b_code_interpreter import AsyncSandbox
|
||||
from pydantic import SecretStr
|
||||
from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
|
||||
from pydantic import BaseModel, JsonValue, SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
@@ -36,6 +37,37 @@ class ProgrammingLanguage(Enum):
|
||||
JAVA = "java"
|
||||
|
||||
|
||||
class CodeExecutionResult(BaseModel):
|
||||
"""
|
||||
*Pydantic model mirroring `e2b_code_interpreter.Result`*
|
||||
|
||||
Represents the data to be displayed as a result of executing a cell in a Jupyter notebook.
|
||||
The result is similar to the structure returned by ipython kernel: https://ipython.readthedocs.io/en/stable/development/execution.html#execution-semantics
|
||||
|
||||
The result can contain multiple types of data, such as text, images, plots, etc. Each type of data is represented
|
||||
as a string, and the result can contain multiple types of data. The display calls don't have to have text representation,
|
||||
for the actual result the representation is always present for the result, the other representations are always optional.
|
||||
"""
|
||||
|
||||
class Chart(BaseModel, E2BExecutionResultChart):
|
||||
pass
|
||||
|
||||
text: Optional[str] = None
|
||||
html: Optional[str] = None
|
||||
markdown: Optional[str] = None
|
||||
svg: Optional[str] = None
|
||||
png: Optional[str] = None
|
||||
jpeg: Optional[str] = None
|
||||
pdf: Optional[str] = None
|
||||
latex: Optional[str] = None
|
||||
json: Optional[JsonValue] = None # type: ignore (reportIncompatibleMethodOverride)
|
||||
javascript: Optional[str] = None
|
||||
data: Optional[dict] = None
|
||||
chart: Optional[Chart] = None
|
||||
extra: Optional[dict] = None
|
||||
"""Extra data that can be included. Not part of the standard types."""
|
||||
|
||||
|
||||
class CodeExecutionBlock(Block):
|
||||
# TODO : Add support to upload and download files
|
||||
# Currently, You can customized the CPU and Memory, only by creating a pre customized sandbox template
|
||||
@@ -87,7 +119,16 @@ class CodeExecutionBlock(Block):
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: str = SchemaField(description="Response from code execution")
|
||||
main_result: CodeExecutionResult = SchemaField(
|
||||
title="Main Result", description="The main result from the code execution"
|
||||
)
|
||||
results: list[CodeExecutionResult] = SchemaField(
|
||||
description="List of results from the code execution"
|
||||
)
|
||||
response: str = SchemaField(
|
||||
title="Main Text Output",
|
||||
description="Text output (if any) of the main execution result",
|
||||
)
|
||||
stdout_logs: str = SchemaField(
|
||||
description="Standard output logs from execution"
|
||||
)
|
||||
@@ -111,14 +152,16 @@ class CodeExecutionBlock(Block):
|
||||
"template_id": "",
|
||||
},
|
||||
test_output=[
|
||||
("results", []),
|
||||
("response", "Hello World"),
|
||||
("stdout_logs", "Hello World\n"),
|
||||
],
|
||||
test_mock={
|
||||
"execute_code": lambda code, language, setup_commands, timeout, api_key, template_id: (
|
||||
"Hello World",
|
||||
"Hello World\n",
|
||||
"",
|
||||
[], # results
|
||||
"Hello World", # text_output
|
||||
"Hello World\n", # stdout_logs
|
||||
"", # stderr_logs
|
||||
),
|
||||
},
|
||||
)
|
||||
@@ -158,11 +201,12 @@ class CodeExecutionBlock(Block):
|
||||
if execution.error:
|
||||
raise Exception(execution.error)
|
||||
|
||||
response = execution.text
|
||||
results = execution.results
|
||||
text_output = execution.text
|
||||
stdout_logs = "".join(execution.logs.stdout)
|
||||
stderr_logs = "".join(execution.logs.stderr)
|
||||
|
||||
return response, stdout_logs, stderr_logs
|
||||
return results, text_output, stdout_logs, stderr_logs
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
@@ -171,7 +215,7 @@ class CodeExecutionBlock(Block):
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
response, stdout_logs, stderr_logs = await self.execute_code(
|
||||
results, text_output, stdout_logs, stderr_logs = await self.execute_code(
|
||||
input_data.code,
|
||||
input_data.language,
|
||||
input_data.setup_commands,
|
||||
@@ -180,8 +224,21 @@ class CodeExecutionBlock(Block):
|
||||
input_data.template_id,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield "response", response
|
||||
# Determine result object shape & filter out empty formats
|
||||
results = [
|
||||
{
|
||||
f: r[f]
|
||||
for f in [*r.formats(), "extra", "is_main_result"]
|
||||
if getattr(r, f, None) is not None
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
yield "results", results
|
||||
for r in results:
|
||||
if r.pop("is_main_result", False):
|
||||
yield "main_result", r
|
||||
if text_output:
|
||||
yield "response", text_output
|
||||
if stdout_logs:
|
||||
yield "stdout_logs", stdout_logs
|
||||
if stderr_logs:
|
||||
@@ -240,7 +297,10 @@ class InstantiationBlock(Block):
|
||||
|
||||
class Output(BlockSchema):
|
||||
sandbox_id: str = SchemaField(description="ID of the sandbox instance")
|
||||
response: str = SchemaField(description="Response from code execution")
|
||||
response: str = SchemaField(
|
||||
title="Text Result",
|
||||
description="Text result (if any) of the setup code execution",
|
||||
)
|
||||
stdout_logs: str = SchemaField(
|
||||
description="Standard output logs from execution"
|
||||
)
|
||||
@@ -270,10 +330,10 @@ class InstantiationBlock(Block):
|
||||
],
|
||||
test_mock={
|
||||
"execute_code": lambda setup_code, language, setup_commands, timeout, api_key, template_id: (
|
||||
"sandbox_id",
|
||||
"Hello World",
|
||||
"Hello World\n",
|
||||
"",
|
||||
"sandbox_id", # sandbox_id
|
||||
"Hello World", # text_output
|
||||
"Hello World\n", # stdout_logs
|
||||
"", # stderr_logs
|
||||
),
|
||||
},
|
||||
)
|
||||
@@ -282,7 +342,7 @@ class InstantiationBlock(Block):
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
sandbox_id, response, stdout_logs, stderr_logs = await self.execute_code(
|
||||
sandbox_id, text_output, stdout_logs, stderr_logs = await self.execute_code(
|
||||
input_data.setup_code,
|
||||
input_data.language,
|
||||
input_data.setup_commands,
|
||||
@@ -294,8 +354,9 @@ class InstantiationBlock(Block):
|
||||
yield "sandbox_id", sandbox_id
|
||||
else:
|
||||
yield "error", "Sandbox ID not found"
|
||||
if response:
|
||||
yield "response", response
|
||||
|
||||
if text_output:
|
||||
yield "response", text_output
|
||||
if stdout_logs:
|
||||
yield "stdout_logs", stdout_logs
|
||||
if stderr_logs:
|
||||
@@ -338,11 +399,11 @@ class InstantiationBlock(Block):
|
||||
if execution.error:
|
||||
raise Exception(execution.error)
|
||||
|
||||
response = execution.text
|
||||
text_output = execution.text
|
||||
stdout_logs = "".join(execution.logs.stdout)
|
||||
stderr_logs = "".join(execution.logs.stderr)
|
||||
|
||||
return sandbox.sandbox_id, response, stdout_logs, stderr_logs
|
||||
return sandbox.sandbox_id, text_output, stdout_logs, stderr_logs
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
@@ -375,7 +436,16 @@ class StepExecutionBlock(Block):
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: str = SchemaField(description="Response from code execution")
|
||||
main_result: CodeExecutionResult = SchemaField(
|
||||
title="Main Result", description="The main result from the code execution"
|
||||
)
|
||||
results: list[CodeExecutionResult] = SchemaField(
|
||||
description="List of results from the code execution"
|
||||
)
|
||||
response: str = SchemaField(
|
||||
title="Main Text Output",
|
||||
description="Text output (if any) of the main execution result",
|
||||
)
|
||||
stdout_logs: str = SchemaField(
|
||||
description="Standard output logs from execution"
|
||||
)
|
||||
@@ -397,14 +467,16 @@ class StepExecutionBlock(Block):
|
||||
"language": ProgrammingLanguage.PYTHON.value,
|
||||
},
|
||||
test_output=[
|
||||
("results", []),
|
||||
("response", "Hello World"),
|
||||
("stdout_logs", "Hello World\n"),
|
||||
],
|
||||
test_mock={
|
||||
"execute_step_code": lambda sandbox_id, step_code, language, api_key: (
|
||||
"Hello World",
|
||||
"Hello World\n",
|
||||
"",
|
||||
[], # results
|
||||
"Hello World", # text_output
|
||||
"Hello World\n", # stdout_logs
|
||||
"", # stderr_logs
|
||||
),
|
||||
},
|
||||
)
|
||||
@@ -427,11 +499,12 @@ class StepExecutionBlock(Block):
|
||||
if execution.error:
|
||||
raise Exception(execution.error)
|
||||
|
||||
response = execution.text
|
||||
results = execution.results
|
||||
text_output = execution.text
|
||||
stdout_logs = "".join(execution.logs.stdout)
|
||||
stderr_logs = "".join(execution.logs.stderr)
|
||||
|
||||
return response, stdout_logs, stderr_logs
|
||||
return results, text_output, stdout_logs, stderr_logs
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
@@ -440,15 +513,30 @@ class StepExecutionBlock(Block):
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
response, stdout_logs, stderr_logs = await self.execute_step_code(
|
||||
input_data.sandbox_id,
|
||||
input_data.step_code,
|
||||
input_data.language,
|
||||
credentials.api_key.get_secret_value(),
|
||||
results, text_output, stdout_logs, stderr_logs = (
|
||||
await self.execute_step_code(
|
||||
input_data.sandbox_id,
|
||||
input_data.step_code,
|
||||
input_data.language,
|
||||
credentials.api_key.get_secret_value(),
|
||||
)
|
||||
)
|
||||
|
||||
if response:
|
||||
yield "response", response
|
||||
# Determine result object shape & filter out empty formats
|
||||
results = [
|
||||
{
|
||||
f: r[f]
|
||||
for f in [*r.formats(), "extra", "is_main_result"]
|
||||
if getattr(r, f, None) is not None
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
yield "results", results
|
||||
for r in results:
|
||||
if r.pop("is_main_result", False):
|
||||
yield "main_result", r
|
||||
if text_output:
|
||||
yield "response", text_output
|
||||
if stdout_logs:
|
||||
yield "stdout_logs", stdout_logs
|
||||
if stderr_logs:
|
||||
|
||||
@@ -113,6 +113,7 @@ class DataForSeoClient:
|
||||
include_serp_info: bool = False,
|
||||
include_clickstream_data: bool = False,
|
||||
limit: int = 100,
|
||||
depth: Optional[int] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get related keywords from DataForSEO Labs.
|
||||
@@ -125,6 +126,7 @@ class DataForSeoClient:
|
||||
include_serp_info: Include SERP data
|
||||
include_clickstream_data: Include clickstream metrics
|
||||
limit: Maximum number of results (up to 3000)
|
||||
depth: Keyword search depth (0-4), controls number of returned keywords
|
||||
|
||||
Returns:
|
||||
API response with related keywords
|
||||
@@ -148,6 +150,8 @@ class DataForSeoClient:
|
||||
task_data["include_clickstream_data"] = include_clickstream_data
|
||||
if limit is not None:
|
||||
task_data["limit"] = limit
|
||||
if depth is not None:
|
||||
task_data["depth"] = depth
|
||||
|
||||
payload = [task_data]
|
||||
|
||||
|
||||
@@ -90,6 +90,7 @@ class DataForSeoKeywordSuggestionsBlock(Block):
|
||||
seed_keyword: str = SchemaField(
|
||||
description="The seed keyword used for the query"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the API call failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -161,43 +162,52 @@ class DataForSeoKeywordSuggestionsBlock(Block):
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Execute the keyword suggestions query."""
|
||||
client = DataForSeoClient(credentials)
|
||||
try:
|
||||
client = DataForSeoClient(credentials)
|
||||
|
||||
results = await self._fetch_keyword_suggestions(client, input_data)
|
||||
results = await self._fetch_keyword_suggestions(client, input_data)
|
||||
|
||||
# Process and format the results
|
||||
suggestions = []
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
items = (
|
||||
first_result.get("items", []) if isinstance(first_result, dict) else []
|
||||
)
|
||||
for item in items:
|
||||
# Create the KeywordSuggestion object
|
||||
suggestion = KeywordSuggestion(
|
||||
keyword=item.get("keyword", ""),
|
||||
search_volume=item.get("keyword_info", {}).get("search_volume"),
|
||||
competition=item.get("keyword_info", {}).get("competition"),
|
||||
cpc=item.get("keyword_info", {}).get("cpc"),
|
||||
keyword_difficulty=item.get("keyword_properties", {}).get(
|
||||
"keyword_difficulty"
|
||||
),
|
||||
serp_info=(
|
||||
item.get("serp_info") if input_data.include_serp_info else None
|
||||
),
|
||||
clickstream_data=(
|
||||
item.get("clickstream_keyword_info")
|
||||
if input_data.include_clickstream_data
|
||||
else None
|
||||
),
|
||||
# Process and format the results
|
||||
suggestions = []
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
items = (
|
||||
first_result.get("items", [])
|
||||
if isinstance(first_result, dict)
|
||||
else []
|
||||
)
|
||||
yield "suggestion", suggestion
|
||||
suggestions.append(suggestion)
|
||||
if items is None:
|
||||
items = []
|
||||
for item in items:
|
||||
# Create the KeywordSuggestion object
|
||||
suggestion = KeywordSuggestion(
|
||||
keyword=item.get("keyword", ""),
|
||||
search_volume=item.get("keyword_info", {}).get("search_volume"),
|
||||
competition=item.get("keyword_info", {}).get("competition"),
|
||||
cpc=item.get("keyword_info", {}).get("cpc"),
|
||||
keyword_difficulty=item.get("keyword_properties", {}).get(
|
||||
"keyword_difficulty"
|
||||
),
|
||||
serp_info=(
|
||||
item.get("serp_info")
|
||||
if input_data.include_serp_info
|
||||
else None
|
||||
),
|
||||
clickstream_data=(
|
||||
item.get("clickstream_keyword_info")
|
||||
if input_data.include_clickstream_data
|
||||
else None
|
||||
),
|
||||
)
|
||||
yield "suggestion", suggestion
|
||||
suggestions.append(suggestion)
|
||||
|
||||
yield "suggestions", suggestions
|
||||
yield "total_count", len(suggestions)
|
||||
yield "seed_keyword", input_data.keyword
|
||||
yield "suggestions", suggestions
|
||||
yield "total_count", len(suggestions)
|
||||
yield "seed_keyword", input_data.keyword
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to fetch keyword suggestions: {str(e)}"
|
||||
|
||||
|
||||
class KeywordSuggestionExtractorBlock(Block):
|
||||
|
||||
@@ -78,6 +78,12 @@ class DataForSeoRelatedKeywordsBlock(Block):
|
||||
ge=1,
|
||||
le=3000,
|
||||
)
|
||||
depth: int = SchemaField(
|
||||
description="Keyword search depth (0-4). Controls the number of returned keywords: 0=1 keyword, 1=~8 keywords, 2=~72 keywords, 3=~584 keywords, 4=~4680 keywords",
|
||||
default=1,
|
||||
ge=0,
|
||||
le=4,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
related_keywords: List[RelatedKeyword] = SchemaField(
|
||||
@@ -92,6 +98,7 @@ class DataForSeoRelatedKeywordsBlock(Block):
|
||||
seed_keyword: str = SchemaField(
|
||||
description="The seed keyword used for the query"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the API call failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -154,6 +161,7 @@ class DataForSeoRelatedKeywordsBlock(Block):
|
||||
include_serp_info=input_data.include_serp_info,
|
||||
include_clickstream_data=input_data.include_clickstream_data,
|
||||
limit=input_data.limit,
|
||||
depth=input_data.depth,
|
||||
)
|
||||
|
||||
async def run(
|
||||
@@ -164,50 +172,60 @@ class DataForSeoRelatedKeywordsBlock(Block):
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Execute the related keywords query."""
|
||||
client = DataForSeoClient(credentials)
|
||||
try:
|
||||
client = DataForSeoClient(credentials)
|
||||
|
||||
results = await self._fetch_related_keywords(client, input_data)
|
||||
results = await self._fetch_related_keywords(client, input_data)
|
||||
|
||||
# Process and format the results
|
||||
related_keywords = []
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
items = (
|
||||
first_result.get("items", []) if isinstance(first_result, dict) else []
|
||||
)
|
||||
for item in items:
|
||||
# Extract keyword_data from the item
|
||||
keyword_data = item.get("keyword_data", {})
|
||||
|
||||
# Create the RelatedKeyword object
|
||||
keyword = RelatedKeyword(
|
||||
keyword=keyword_data.get("keyword", ""),
|
||||
search_volume=keyword_data.get("keyword_info", {}).get(
|
||||
"search_volume"
|
||||
),
|
||||
competition=keyword_data.get("keyword_info", {}).get("competition"),
|
||||
cpc=keyword_data.get("keyword_info", {}).get("cpc"),
|
||||
keyword_difficulty=keyword_data.get("keyword_properties", {}).get(
|
||||
"keyword_difficulty"
|
||||
),
|
||||
serp_info=(
|
||||
keyword_data.get("serp_info")
|
||||
if input_data.include_serp_info
|
||||
else None
|
||||
),
|
||||
clickstream_data=(
|
||||
keyword_data.get("clickstream_keyword_info")
|
||||
if input_data.include_clickstream_data
|
||||
else None
|
||||
),
|
||||
# Process and format the results
|
||||
related_keywords = []
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
items = (
|
||||
first_result.get("items", [])
|
||||
if isinstance(first_result, dict)
|
||||
else []
|
||||
)
|
||||
yield "related_keyword", keyword
|
||||
related_keywords.append(keyword)
|
||||
# Ensure items is never None
|
||||
if items is None:
|
||||
items = []
|
||||
for item in items:
|
||||
# Extract keyword_data from the item
|
||||
keyword_data = item.get("keyword_data", {})
|
||||
|
||||
yield "related_keywords", related_keywords
|
||||
yield "total_count", len(related_keywords)
|
||||
yield "seed_keyword", input_data.keyword
|
||||
# Create the RelatedKeyword object
|
||||
keyword = RelatedKeyword(
|
||||
keyword=keyword_data.get("keyword", ""),
|
||||
search_volume=keyword_data.get("keyword_info", {}).get(
|
||||
"search_volume"
|
||||
),
|
||||
competition=keyword_data.get("keyword_info", {}).get(
|
||||
"competition"
|
||||
),
|
||||
cpc=keyword_data.get("keyword_info", {}).get("cpc"),
|
||||
keyword_difficulty=keyword_data.get(
|
||||
"keyword_properties", {}
|
||||
).get("keyword_difficulty"),
|
||||
serp_info=(
|
||||
keyword_data.get("serp_info")
|
||||
if input_data.include_serp_info
|
||||
else None
|
||||
),
|
||||
clickstream_data=(
|
||||
keyword_data.get("clickstream_keyword_info")
|
||||
if input_data.include_clickstream_data
|
||||
else None
|
||||
),
|
||||
)
|
||||
yield "related_keyword", keyword
|
||||
related_keywords.append(keyword)
|
||||
|
||||
yield "related_keywords", related_keywords
|
||||
yield "total_count", len(related_keywords)
|
||||
yield "seed_keyword", input_data.keyword
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to fetch related keywords: {str(e)}"
|
||||
|
||||
|
||||
class RelatedKeywordExtractorBlock(Block):
|
||||
|
||||
@@ -10,7 +10,6 @@ from backend.util.settings import Config
|
||||
from backend.util.text import TextFormatter
|
||||
from backend.util.type import LongTextType, MediaFileType, ShortTextType
|
||||
|
||||
formatter = TextFormatter()
|
||||
config = Config()
|
||||
|
||||
|
||||
@@ -132,6 +131,11 @@ class AgentOutputBlock(Block):
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
escape_html: bool = SchemaField(
|
||||
default=False,
|
||||
advanced=True,
|
||||
description="Whether to escape special characters in the inserted values to be HTML-safe. Enable for HTML output, disable for plain text.",
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to treat the output as advanced.",
|
||||
default=False,
|
||||
@@ -193,6 +197,7 @@ class AgentOutputBlock(Block):
|
||||
"""
|
||||
if input_data.format:
|
||||
try:
|
||||
formatter = TextFormatter(autoescape=input_data.escape_html)
|
||||
yield "output", formatter.format_string(
|
||||
input_data.format, {input_data.name: input_data.value}
|
||||
)
|
||||
@@ -549,6 +554,89 @@ class AgentToggleInputBlock(AgentInputBlock):
|
||||
)
|
||||
|
||||
|
||||
class AgentTableInputBlock(AgentInputBlock):
|
||||
"""
|
||||
This block allows users to input data in a table format.
|
||||
|
||||
Configure the table columns at build time, then users can input
|
||||
rows of data at runtime. Each row is output as a dictionary
|
||||
with column names as keys.
|
||||
"""
|
||||
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[list[dict[str, Any]]] = SchemaField(
|
||||
description="The table data as a list of dictionaries.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
column_headers: list[str] = SchemaField(
|
||||
description="Column headers for the table.",
|
||||
default_factory=lambda: ["Column 1", "Column 2", "Column 3"],
|
||||
advanced=False,
|
||||
title="Column Headers",
|
||||
)
|
||||
|
||||
def generate_schema(self):
|
||||
"""Generate schema for the value field with table format."""
|
||||
schema = super().generate_schema()
|
||||
schema["type"] = "array"
|
||||
schema["format"] = "table"
|
||||
schema["items"] = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
header: {"type": "string"}
|
||||
for header in (
|
||||
self.column_headers or ["Column 1", "Column 2", "Column 3"]
|
||||
)
|
||||
},
|
||||
}
|
||||
if self.value is not None:
|
||||
schema["default"] = self.value
|
||||
return schema
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: list[dict[str, Any]] = SchemaField(
|
||||
description="The table data as a list of dictionaries with headers as keys."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="5603b273-f41e-4020-af7d-fbc9c6a8d928",
|
||||
description="Block for table data input with customizable headers.",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentTableInputBlock.Input,
|
||||
output_schema=AgentTableInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"name": "test_table",
|
||||
"column_headers": ["Name", "Age", "City"],
|
||||
"value": [
|
||||
{"Name": "John", "Age": "30", "City": "New York"},
|
||||
{"Name": "Jane", "Age": "25", "City": "London"},
|
||||
],
|
||||
"description": "Example table input",
|
||||
}
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
"result",
|
||||
[
|
||||
{"Name": "John", "Age": "30", "City": "New York"},
|
||||
{"Name": "Jane", "Age": "25", "City": "London"},
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, *args, **kwargs) -> BlockOutput:
|
||||
"""
|
||||
Yields the table data as a list of dictionaries.
|
||||
"""
|
||||
# Pass through the value, defaulting to empty list if None
|
||||
yield "result", input_data.value if input_data.value is not None else []
|
||||
|
||||
|
||||
IO_BLOCK_IDs = [
|
||||
AgentInputBlock().id,
|
||||
AgentOutputBlock().id,
|
||||
@@ -560,4 +648,5 @@ IO_BLOCK_IDs = [
|
||||
AgentFileInputBlock().id,
|
||||
AgentDropdownInputBlock().id,
|
||||
AgentToggleInputBlock().id,
|
||||
AgentTableInputBlock().id,
|
||||
]
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
# This file contains a lot of prompt block strings that would trigger "line too long"
|
||||
# flake8: noqa: E501
|
||||
import ast
|
||||
import logging
|
||||
import re
|
||||
import secrets
|
||||
from abc import ABC
|
||||
from enum import Enum, EnumMeta
|
||||
from json import JSONDecodeError
|
||||
@@ -27,7 +31,7 @@ from backend.util.prompt import compress_prompt, estimate_token_count
|
||||
from backend.util.text import TextFormatter
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
||||
fmt = TextFormatter()
|
||||
fmt = TextFormatter(autoescape=False)
|
||||
|
||||
LLMProviderName = Literal[
|
||||
ProviderName.AIML_API,
|
||||
@@ -97,6 +101,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
|
||||
CLAUDE_4_OPUS = "claude-opus-4-20250514"
|
||||
CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
|
||||
CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
|
||||
CLAUDE_3_5_HAIKU = "claude-3-5-haiku-latest"
|
||||
@@ -204,13 +209,16 @@ MODEL_METADATA = {
|
||||
"anthropic", 200000, 32000
|
||||
), # claude-opus-4-1-20250805
|
||||
LlmModel.CLAUDE_4_OPUS: ModelMetadata(
|
||||
"anthropic", 200000, 8192
|
||||
"anthropic", 200000, 32000
|
||||
), # claude-4-opus-20250514
|
||||
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 8192
|
||||
"anthropic", 200000, 64000
|
||||
), # claude-4-sonnet-20250514
|
||||
LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000
|
||||
), # claude-sonnet-4-5-20250929
|
||||
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 8192
|
||||
"anthropic", 200000, 64000
|
||||
), # claude-3-7-sonnet-20250219
|
||||
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 8192
|
||||
@@ -382,7 +390,9 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
|
||||
return None
|
||||
|
||||
|
||||
def get_parallel_tool_calls_param(llm_model: LlmModel, parallel_tool_calls):
|
||||
def get_parallel_tool_calls_param(
|
||||
llm_model: LlmModel, parallel_tool_calls: bool | None
|
||||
):
|
||||
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
|
||||
if llm_model.startswith("o") or parallel_tool_calls is None:
|
||||
return openai.NOT_GIVEN
|
||||
@@ -393,8 +403,8 @@ async def llm_call(
|
||||
credentials: APIKeyCredentials,
|
||||
llm_model: LlmModel,
|
||||
prompt: list[dict],
|
||||
json_format: bool,
|
||||
max_tokens: int | None,
|
||||
force_json_output: bool = False,
|
||||
tools: list[dict] | None = None,
|
||||
ollama_host: str = "localhost:11434",
|
||||
parallel_tool_calls=None,
|
||||
@@ -407,7 +417,7 @@ async def llm_call(
|
||||
credentials: The API key credentials to use.
|
||||
llm_model: The LLM model to use.
|
||||
prompt: The prompt to send to the LLM.
|
||||
json_format: Whether the response should be in JSON format.
|
||||
force_json_output: Whether the response should be in JSON format.
|
||||
max_tokens: The maximum number of tokens to generate in the chat completion.
|
||||
tools: The tools to use in the chat completion.
|
||||
ollama_host: The host for ollama to use.
|
||||
@@ -446,7 +456,7 @@ async def llm_call(
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
if json_format:
|
||||
if force_json_output:
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
response = await oai_client.chat.completions.create(
|
||||
@@ -559,7 +569,7 @@ async def llm_call(
|
||||
raise ValueError("Groq does not support tools.")
|
||||
|
||||
client = AsyncGroq(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = {"type": "json_object"} if json_format else None
|
||||
response_format = {"type": "json_object"} if force_json_output else None
|
||||
response = await client.chat.completions.create(
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
@@ -717,7 +727,7 @@ async def llm_call(
|
||||
)
|
||||
|
||||
response_format = None
|
||||
if json_format:
|
||||
if force_json_output:
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
parallel_tool_calls_param = get_parallel_tool_calls_param(
|
||||
@@ -780,6 +790,17 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
force_json_output: bool = SchemaField(
|
||||
title="Restrict LLM to pure JSON output",
|
||||
default=False,
|
||||
description=(
|
||||
"Whether to force the LLM to produce a JSON-only response. "
|
||||
"This can increase the block's reliability, "
|
||||
"but may also reduce the quality of the response "
|
||||
"because it prohibits the LLM from reasoning "
|
||||
"before providing its JSON response."
|
||||
),
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
sys_prompt: str = SchemaField(
|
||||
title="System Prompt",
|
||||
@@ -848,17 +869,18 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
"llm_call": lambda *args, **kwargs: LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[""],
|
||||
response=json.dumps(
|
||||
{
|
||||
"key1": "key1Value",
|
||||
"key2": "key2Value",
|
||||
}
|
||||
response=(
|
||||
'<json_output id="test123456">{\n'
|
||||
' "key1": "key1Value",\n'
|
||||
' "key2": "key2Value"\n'
|
||||
"}</json_output>"
|
||||
),
|
||||
tool_calls=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
reasoning=None,
|
||||
)
|
||||
),
|
||||
"get_collision_proof_output_tag_id": lambda *args: "test123456",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -867,9 +889,9 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
credentials: APIKeyCredentials,
|
||||
llm_model: LlmModel,
|
||||
prompt: list[dict],
|
||||
json_format: bool,
|
||||
compress_prompt_to_fit: bool,
|
||||
max_tokens: int | None,
|
||||
force_json_output: bool = False,
|
||||
compress_prompt_to_fit: bool = True,
|
||||
tools: list[dict] | None = None,
|
||||
ollama_host: str = "localhost:11434",
|
||||
) -> LLMResponse:
|
||||
@@ -882,8 +904,8 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
credentials=credentials,
|
||||
llm_model=llm_model,
|
||||
prompt=prompt,
|
||||
json_format=json_format,
|
||||
max_tokens=max_tokens,
|
||||
force_json_output=force_json_output,
|
||||
tools=tools,
|
||||
ollama_host=ollama_host,
|
||||
compress_prompt_to_fit=compress_prompt_to_fit,
|
||||
@@ -895,11 +917,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
logger.debug(f"Calling LLM with input data: {input_data}")
|
||||
prompt = [json.to_dict(p) for p in input_data.conversation_history]
|
||||
|
||||
def trim_prompt(s: str) -> str:
|
||||
"""Removes indentation up to and including `|` from a multi-line prompt."""
|
||||
lines = s.strip().split("\n")
|
||||
return "\n".join([line.strip().lstrip("|") for line in lines])
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
input_data.prompt = fmt.format_string(input_data.prompt, values)
|
||||
@@ -908,28 +925,15 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
if input_data.sys_prompt:
|
||||
prompt.append({"role": "system", "content": input_data.sys_prompt})
|
||||
|
||||
# Use a one-time unique tag to prevent collisions with user/LLM content
|
||||
output_tag_id = self.get_collision_proof_output_tag_id()
|
||||
output_tag_start = f'<json_output id="{output_tag_id}">'
|
||||
if input_data.expected_format:
|
||||
expected_format = [
|
||||
f"{json.dumps(k)}: {json.dumps(v)}"
|
||||
for k, v in input_data.expected_format.items()
|
||||
]
|
||||
if input_data.list_result:
|
||||
format_prompt = (
|
||||
f'"results": [\n {{\n {", ".join(expected_format)}\n }}\n]'
|
||||
)
|
||||
else:
|
||||
format_prompt = ",\n| ".join(expected_format)
|
||||
|
||||
sys_prompt = trim_prompt(
|
||||
f"""
|
||||
|Reply with pure JSON strictly following this JSON format:
|
||||
|{{
|
||||
| {format_prompt}
|
||||
|}}
|
||||
|
|
||||
|Ensure the response is valid JSON. DO NOT include any additional text (e.g. markdown code block fences) outside of the JSON.
|
||||
|If you cannot provide all the keys, provide an empty string for the values you cannot answer.
|
||||
"""
|
||||
sys_prompt = self.response_format_instructions(
|
||||
input_data.expected_format,
|
||||
list_mode=input_data.list_result,
|
||||
pure_json_mode=input_data.force_json_output,
|
||||
output_tag_start=output_tag_start,
|
||||
)
|
||||
prompt.append({"role": "system", "content": sys_prompt})
|
||||
|
||||
@@ -947,18 +951,21 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
except JSONDecodeError as e:
|
||||
return f"JSON decode error: {e}"
|
||||
|
||||
logger.debug(f"LLM request: {prompt}")
|
||||
error_feedback_message = ""
|
||||
llm_model = input_data.model
|
||||
|
||||
for retry_count in range(input_data.retry):
|
||||
logger.debug(f"LLM request: {prompt}")
|
||||
try:
|
||||
llm_response = await self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=llm_model,
|
||||
prompt=prompt,
|
||||
compress_prompt_to_fit=input_data.compress_prompt_to_fit,
|
||||
json_format=bool(input_data.expected_format),
|
||||
force_json_output=(
|
||||
input_data.force_json_output
|
||||
and bool(input_data.expected_format)
|
||||
),
|
||||
ollama_host=input_data.ollama_host,
|
||||
max_tokens=input_data.max_tokens,
|
||||
)
|
||||
@@ -973,30 +980,52 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
||||
if input_data.expected_format:
|
||||
try:
|
||||
response_obj = json.loads(response_text)
|
||||
except JSONDecodeError as json_error:
|
||||
response_obj = self.get_json_from_response(
|
||||
response_text,
|
||||
pure_json_mode=input_data.force_json_output,
|
||||
output_tag_start=output_tag_start,
|
||||
)
|
||||
except (ValueError, JSONDecodeError) as parse_error:
|
||||
censored_response = re.sub(r"[A-Za-z0-9]", "*", response_text)
|
||||
response_snippet = (
|
||||
f"{censored_response[:50]}...{censored_response[-30:]}"
|
||||
)
|
||||
logger.warning(
|
||||
f"Error getting JSON from LLM response: {parse_error}\n\n"
|
||||
f"Response start+end: `{response_snippet}`"
|
||||
)
|
||||
prompt.append({"role": "assistant", "content": response_text})
|
||||
|
||||
indented_json_error = str(json_error).replace("\n", "\n|")
|
||||
error_feedback_message = trim_prompt(
|
||||
f"""
|
||||
|Your previous response could not be parsed as valid JSON:
|
||||
|
|
||||
|{indented_json_error}
|
||||
|
|
||||
|Please provide a valid JSON response that matches the expected format.
|
||||
"""
|
||||
error_feedback_message = self.invalid_response_feedback(
|
||||
parse_error,
|
||||
was_parseable=False,
|
||||
list_mode=input_data.list_result,
|
||||
pure_json_mode=input_data.force_json_output,
|
||||
output_tag_start=output_tag_start,
|
||||
)
|
||||
prompt.append(
|
||||
{"role": "user", "content": error_feedback_message}
|
||||
)
|
||||
continue
|
||||
|
||||
# Handle object response for `force_json_output`+`list_result`
|
||||
if input_data.list_result and isinstance(response_obj, dict):
|
||||
if "results" in response_obj:
|
||||
response_obj = response_obj.get("results", [])
|
||||
elif len(response_obj) == 1:
|
||||
response_obj = list(response_obj.values())
|
||||
if "results" in response_obj and isinstance(
|
||||
response_obj["results"], list
|
||||
):
|
||||
response_obj = response_obj["results"]
|
||||
else:
|
||||
error_feedback_message = (
|
||||
"Expected an array of objects in the 'results' key, "
|
||||
f"but got: {response_obj}"
|
||||
)
|
||||
prompt.append(
|
||||
{"role": "assistant", "content": response_text}
|
||||
)
|
||||
prompt.append(
|
||||
{"role": "user", "content": error_feedback_message}
|
||||
)
|
||||
continue
|
||||
|
||||
validation_errors = "\n".join(
|
||||
[
|
||||
@@ -1022,12 +1051,12 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
return
|
||||
|
||||
prompt.append({"role": "assistant", "content": response_text})
|
||||
error_feedback_message = trim_prompt(
|
||||
f"""
|
||||
|Your response did not match the expected format:
|
||||
|
|
||||
|{validation_errors}
|
||||
"""
|
||||
error_feedback_message = self.invalid_response_feedback(
|
||||
validation_errors,
|
||||
was_parseable=True,
|
||||
list_mode=input_data.list_result,
|
||||
pure_json_mode=input_data.force_json_output,
|
||||
output_tag_start=output_tag_start,
|
||||
)
|
||||
prompt.append({"role": "user", "content": error_feedback_message})
|
||||
else:
|
||||
@@ -1059,6 +1088,127 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
||||
raise RuntimeError(error_feedback_message)
|
||||
|
||||
def response_format_instructions(
|
||||
self,
|
||||
expected_object_format: dict[str, str],
|
||||
*,
|
||||
list_mode: bool,
|
||||
pure_json_mode: bool,
|
||||
output_tag_start: str,
|
||||
) -> str:
|
||||
expected_output_format = json.dumps(expected_object_format, indent=2)
|
||||
output_type = "object" if not list_mode else "array"
|
||||
outer_output_type = "object" if pure_json_mode else output_type
|
||||
|
||||
if output_type == "array":
|
||||
indented_obj_format = expected_output_format.replace("\n", "\n ")
|
||||
expected_output_format = f"[\n {indented_obj_format},\n ...\n]"
|
||||
if pure_json_mode:
|
||||
indented_list_format = expected_output_format.replace("\n", "\n ")
|
||||
expected_output_format = (
|
||||
"{\n"
|
||||
' "reasoning": "... (optional)",\n' # for better performance
|
||||
f' "results": {indented_list_format}\n'
|
||||
"}"
|
||||
)
|
||||
|
||||
# Preserve indentation in prompt
|
||||
expected_output_format = expected_output_format.replace("\n", "\n|")
|
||||
|
||||
# Prepare prompt
|
||||
if not pure_json_mode:
|
||||
expected_output_format = (
|
||||
f"{output_tag_start}\n{expected_output_format}\n</json_output>"
|
||||
)
|
||||
|
||||
instructions = f"""
|
||||
|In your response you MUST include a valid JSON {outer_output_type} strictly following this format:
|
||||
|{expected_output_format}
|
||||
|
|
||||
|If you cannot provide all the keys, you MUST provide an empty string for the values you cannot answer.
|
||||
""".strip()
|
||||
|
||||
if not pure_json_mode:
|
||||
instructions += f"""
|
||||
|
|
||||
|You MUST enclose your final JSON answer in {output_tag_start}...</json_output> tags, even if the user specifies a different tag.
|
||||
|There MUST be exactly ONE {output_tag_start}...</json_output> block in your response, which MUST ONLY contain the JSON {outer_output_type} and nothing else. Other text outside this block is allowed.
|
||||
""".strip()
|
||||
|
||||
return trim_prompt(instructions)
|
||||
|
||||
def invalid_response_feedback(
|
||||
self,
|
||||
error,
|
||||
*,
|
||||
was_parseable: bool,
|
||||
list_mode: bool,
|
||||
pure_json_mode: bool,
|
||||
output_tag_start: str,
|
||||
) -> str:
|
||||
outer_output_type = "object" if not list_mode or pure_json_mode else "array"
|
||||
|
||||
if was_parseable:
|
||||
complaint = f"Your previous response did not match the expected {outer_output_type} format."
|
||||
else:
|
||||
complaint = f"Your previous response did not contain a parseable JSON {outer_output_type}."
|
||||
|
||||
indented_parse_error = str(error).replace("\n", "\n|")
|
||||
|
||||
instruction = (
|
||||
f"Please provide a {output_tag_start}...</json_output> block containing a"
|
||||
if not pure_json_mode
|
||||
else "Please provide a"
|
||||
) + f" valid JSON {outer_output_type} that matches the expected format."
|
||||
|
||||
return trim_prompt(
|
||||
f"""
|
||||
|{complaint}
|
||||
|
|
||||
|{indented_parse_error}
|
||||
|
|
||||
|{instruction}
|
||||
"""
|
||||
)
|
||||
|
||||
def get_json_from_response(
|
||||
self, response_text: str, *, pure_json_mode: bool, output_tag_start: str
|
||||
) -> dict[str, Any] | list[dict[str, Any]]:
|
||||
if pure_json_mode:
|
||||
# Handle pure JSON responses
|
||||
try:
|
||||
return json.loads(response_text)
|
||||
except JSONDecodeError as first_parse_error:
|
||||
# If that didn't work, try finding the { and } to deal with possible ```json fences etc.
|
||||
json_start = response_text.find("{")
|
||||
json_end = response_text.rfind("}")
|
||||
try:
|
||||
return json.loads(response_text[json_start : json_end + 1])
|
||||
except JSONDecodeError:
|
||||
# Raise the original error, as it's more likely to be relevant
|
||||
raise first_parse_error from None
|
||||
|
||||
if output_tag_start not in response_text:
|
||||
raise ValueError(
|
||||
"Response does not contain the expected "
|
||||
f"{output_tag_start}...</json_output> block."
|
||||
)
|
||||
json_output = (
|
||||
response_text.split(output_tag_start, 1)[1]
|
||||
.rsplit("</json_output>", 1)[0]
|
||||
.strip()
|
||||
)
|
||||
return json.loads(json_output)
|
||||
|
||||
def get_collision_proof_output_tag_id(self) -> str:
|
||||
return secrets.token_hex(8)
|
||||
|
||||
|
||||
def trim_prompt(s: str) -> str:
|
||||
"""Removes indentation up to and including `|` from a multi-line prompt."""
|
||||
lines = s.strip().split("\n")
|
||||
return "\n".join([line.strip().lstrip("|") for line in lines])
|
||||
|
||||
|
||||
class AITextGeneratorBlock(AIBlockBase):
|
||||
class Input(BlockSchema):
|
||||
|
||||
536
autogpt_platform/backend/backend/blocks/notion/_api.py
Normal file
536
autogpt_platform/backend/backend/blocks/notion/_api.py
Normal file
@@ -0,0 +1,536 @@
|
||||
"""
|
||||
Notion API helper functions and client for making authenticated requests.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.util.request import Requests
|
||||
|
||||
NOTION_VERSION = "2022-06-28"
|
||||
|
||||
|
||||
class NotionAPIException(Exception):
|
||||
"""Exception raised for Notion API errors."""
|
||||
|
||||
def __init__(self, message: str, status_code: int):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class NotionClient:
|
||||
"""Client for interacting with the Notion API."""
|
||||
|
||||
def __init__(self, credentials: OAuth2Credentials):
|
||||
self.credentials = credentials
|
||||
self.headers = {
|
||||
"Authorization": credentials.auth_header(),
|
||||
"Notion-Version": NOTION_VERSION,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
self.requests = Requests()
|
||||
|
||||
async def get_page(self, page_id: str) -> dict:
|
||||
"""
|
||||
Fetch a page by ID.
|
||||
|
||||
Args:
|
||||
page_id: The ID of the page to fetch.
|
||||
|
||||
Returns:
|
||||
The page object from Notion API.
|
||||
"""
|
||||
url = f"https://api.notion.com/v1/pages/{page_id}"
|
||||
response = await self.requests.get(url, headers=self.headers)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Failed to fetch page: {response.status} - {response.text()}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def get_blocks(self, block_id: str, recursive: bool = True) -> List[dict]:
|
||||
"""
|
||||
Fetch all blocks from a page or block.
|
||||
|
||||
Args:
|
||||
block_id: The ID of the page or block to fetch children from.
|
||||
recursive: Whether to fetch nested blocks recursively.
|
||||
|
||||
Returns:
|
||||
List of block objects.
|
||||
"""
|
||||
blocks = []
|
||||
cursor = None
|
||||
|
||||
while True:
|
||||
url = f"https://api.notion.com/v1/blocks/{block_id}/children"
|
||||
params = {"page_size": 100}
|
||||
if cursor:
|
||||
params["start_cursor"] = cursor
|
||||
|
||||
response = await self.requests.get(url, headers=self.headers, params=params)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Failed to fetch blocks: {response.status} - {response.text()}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
current_blocks = data.get("results", [])
|
||||
|
||||
# If recursive, fetch children for blocks that have them
|
||||
if recursive:
|
||||
for block in current_blocks:
|
||||
if block.get("has_children"):
|
||||
block["children"] = await self.get_blocks(
|
||||
block["id"], recursive=True
|
||||
)
|
||||
|
||||
blocks.extend(current_blocks)
|
||||
|
||||
if not data.get("has_more"):
|
||||
break
|
||||
cursor = data.get("next_cursor")
|
||||
|
||||
return blocks
|
||||
|
||||
async def query_database(
|
||||
self,
|
||||
database_id: str,
|
||||
filter_obj: Optional[dict] = None,
|
||||
sorts: Optional[List[dict]] = None,
|
||||
page_size: int = 100,
|
||||
) -> dict:
|
||||
"""
|
||||
Query a database with optional filters and sorts.
|
||||
|
||||
Args:
|
||||
database_id: The ID of the database to query.
|
||||
filter_obj: Optional filter object for the query.
|
||||
sorts: Optional list of sort objects.
|
||||
page_size: Number of results per page.
|
||||
|
||||
Returns:
|
||||
Query results including pages and pagination info.
|
||||
"""
|
||||
url = f"https://api.notion.com/v1/databases/{database_id}/query"
|
||||
|
||||
payload: Dict[str, Any] = {"page_size": page_size}
|
||||
if filter_obj:
|
||||
payload["filter"] = filter_obj
|
||||
if sorts:
|
||||
payload["sorts"] = sorts
|
||||
|
||||
response = await self.requests.post(url, headers=self.headers, json=payload)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Failed to query database: {response.status} - {response.text()}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def create_page(
|
||||
self,
|
||||
parent: dict,
|
||||
properties: dict,
|
||||
children: Optional[List[dict]] = None,
|
||||
icon: Optional[dict] = None,
|
||||
cover: Optional[dict] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Create a new page.
|
||||
|
||||
Args:
|
||||
parent: Parent object (page_id or database_id).
|
||||
properties: Page properties.
|
||||
children: Optional list of block children.
|
||||
icon: Optional icon object.
|
||||
cover: Optional cover object.
|
||||
|
||||
Returns:
|
||||
The created page object.
|
||||
"""
|
||||
url = "https://api.notion.com/v1/pages"
|
||||
|
||||
payload: Dict[str, Any] = {"parent": parent, "properties": properties}
|
||||
|
||||
if children:
|
||||
payload["children"] = children
|
||||
if icon:
|
||||
payload["icon"] = icon
|
||||
if cover:
|
||||
payload["cover"] = cover
|
||||
|
||||
response = await self.requests.post(url, headers=self.headers, json=payload)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Failed to create page: {response.status} - {response.text()}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def update_page(self, page_id: str, properties: dict) -> dict:
|
||||
"""
|
||||
Update a page's properties.
|
||||
|
||||
Args:
|
||||
page_id: The ID of the page to update.
|
||||
properties: Properties to update.
|
||||
|
||||
Returns:
|
||||
The updated page object.
|
||||
"""
|
||||
url = f"https://api.notion.com/v1/pages/{page_id}"
|
||||
|
||||
response = await self.requests.patch(
|
||||
url, headers=self.headers, json={"properties": properties}
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Failed to update page: {response.status} - {response.text()}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def append_blocks(self, block_id: str, children: List[dict]) -> dict:
|
||||
"""
|
||||
Append blocks to a page or block.
|
||||
|
||||
Args:
|
||||
block_id: The ID of the page or block to append to.
|
||||
children: List of block objects to append.
|
||||
|
||||
Returns:
|
||||
Response with the created blocks.
|
||||
"""
|
||||
url = f"https://api.notion.com/v1/blocks/{block_id}/children"
|
||||
|
||||
response = await self.requests.patch(
|
||||
url, headers=self.headers, json={"children": children}
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Failed to append blocks: {response.status} - {response.text()}",
|
||||
response.status,
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str = "",
|
||||
filter_obj: Optional[dict] = None,
|
||||
sort: Optional[dict] = None,
|
||||
page_size: int = 100,
|
||||
) -> dict:
|
||||
"""
|
||||
Search for pages and databases.
|
||||
|
||||
Args:
|
||||
query: Search query text.
|
||||
filter_obj: Optional filter object.
|
||||
sort: Optional sort object.
|
||||
page_size: Number of results per page.
|
||||
|
||||
Returns:
|
||||
Search results.
|
||||
"""
|
||||
url = "https://api.notion.com/v1/search"
|
||||
|
||||
payload: Dict[str, Any] = {"page_size": page_size}
|
||||
if query:
|
||||
payload["query"] = query
|
||||
if filter_obj:
|
||||
payload["filter"] = filter_obj
|
||||
if sort:
|
||||
payload["sort"] = sort
|
||||
|
||||
response = await self.requests.post(url, headers=self.headers, json=payload)
|
||||
|
||||
if not response.ok:
|
||||
raise NotionAPIException(
|
||||
f"Search failed: {response.status} - {response.text()}", response.status
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
|
||||
# Conversion helper functions
|
||||
|
||||
|
||||
def parse_rich_text(rich_text_array: List[dict]) -> str:
|
||||
"""
|
||||
Extract plain text from a Notion rich text array.
|
||||
|
||||
Args:
|
||||
rich_text_array: Array of rich text objects from Notion.
|
||||
|
||||
Returns:
|
||||
Plain text string.
|
||||
"""
|
||||
if not rich_text_array:
|
||||
return ""
|
||||
|
||||
text_parts = []
|
||||
for text_obj in rich_text_array:
|
||||
if "plain_text" in text_obj:
|
||||
text_parts.append(text_obj["plain_text"])
|
||||
|
||||
return "".join(text_parts)
|
||||
|
||||
|
||||
def rich_text_to_markdown(rich_text_array: List[dict]) -> str:
|
||||
"""
|
||||
Convert Notion rich text array to markdown with formatting.
|
||||
|
||||
Args:
|
||||
rich_text_array: Array of rich text objects from Notion.
|
||||
|
||||
Returns:
|
||||
Markdown formatted string.
|
||||
"""
|
||||
if not rich_text_array:
|
||||
return ""
|
||||
|
||||
markdown_parts = []
|
||||
|
||||
for text_obj in rich_text_array:
|
||||
text = text_obj.get("plain_text", "")
|
||||
annotations = text_obj.get("annotations", {})
|
||||
|
||||
# Apply formatting based on annotations
|
||||
if annotations.get("code"):
|
||||
text = f"`{text}`"
|
||||
else:
|
||||
if annotations.get("bold"):
|
||||
text = f"**{text}**"
|
||||
if annotations.get("italic"):
|
||||
text = f"*{text}*"
|
||||
if annotations.get("strikethrough"):
|
||||
text = f"~~{text}~~"
|
||||
if annotations.get("underline"):
|
||||
text = f"<u>{text}</u>"
|
||||
|
||||
# Handle links
|
||||
if text_obj.get("href"):
|
||||
text = f"[{text}]({text_obj['href']})"
|
||||
|
||||
markdown_parts.append(text)
|
||||
|
||||
return "".join(markdown_parts)
|
||||
|
||||
|
||||
def block_to_markdown(block: dict, indent_level: int = 0) -> str:
|
||||
"""
|
||||
Convert a single Notion block to markdown.
|
||||
|
||||
Args:
|
||||
block: Block object from Notion API.
|
||||
indent_level: Current indentation level for nested blocks.
|
||||
|
||||
Returns:
|
||||
Markdown string representation of the block.
|
||||
"""
|
||||
block_type = block.get("type")
|
||||
indent = " " * indent_level
|
||||
markdown_lines = []
|
||||
|
||||
# Handle different block types
|
||||
if block_type == "paragraph":
|
||||
text = rich_text_to_markdown(block["paragraph"].get("rich_text", []))
|
||||
if text:
|
||||
markdown_lines.append(f"{indent}{text}")
|
||||
|
||||
elif block_type == "heading_1":
|
||||
text = parse_rich_text(block["heading_1"].get("rich_text", []))
|
||||
markdown_lines.append(f"{indent}# {text}")
|
||||
|
||||
elif block_type == "heading_2":
|
||||
text = parse_rich_text(block["heading_2"].get("rich_text", []))
|
||||
markdown_lines.append(f"{indent}## {text}")
|
||||
|
||||
elif block_type == "heading_3":
|
||||
text = parse_rich_text(block["heading_3"].get("rich_text", []))
|
||||
markdown_lines.append(f"{indent}### {text}")
|
||||
|
||||
elif block_type == "bulleted_list_item":
|
||||
text = rich_text_to_markdown(block["bulleted_list_item"].get("rich_text", []))
|
||||
markdown_lines.append(f"{indent}- {text}")
|
||||
|
||||
elif block_type == "numbered_list_item":
|
||||
text = rich_text_to_markdown(block["numbered_list_item"].get("rich_text", []))
|
||||
# Note: This is simplified - proper numbering would need context
|
||||
markdown_lines.append(f"{indent}1. {text}")
|
||||
|
||||
elif block_type == "to_do":
|
||||
text = rich_text_to_markdown(block["to_do"].get("rich_text", []))
|
||||
checked = "x" if block["to_do"].get("checked") else " "
|
||||
markdown_lines.append(f"{indent}- [{checked}] {text}")
|
||||
|
||||
elif block_type == "toggle":
|
||||
text = rich_text_to_markdown(block["toggle"].get("rich_text", []))
|
||||
markdown_lines.append(f"{indent}<details>")
|
||||
markdown_lines.append(f"{indent}<summary>{text}</summary>")
|
||||
markdown_lines.append(f"{indent}")
|
||||
# Process children if they exist
|
||||
if block.get("children"):
|
||||
for child in block["children"]:
|
||||
child_markdown = block_to_markdown(child, indent_level + 1)
|
||||
if child_markdown:
|
||||
markdown_lines.append(child_markdown)
|
||||
markdown_lines.append(f"{indent}</details>")
|
||||
|
||||
elif block_type == "code":
|
||||
code = parse_rich_text(block["code"].get("rich_text", []))
|
||||
language = block["code"].get("language", "")
|
||||
markdown_lines.append(f"{indent}```{language}")
|
||||
markdown_lines.append(f"{indent}{code}")
|
||||
markdown_lines.append(f"{indent}```")
|
||||
|
||||
elif block_type == "quote":
|
||||
text = rich_text_to_markdown(block["quote"].get("rich_text", []))
|
||||
markdown_lines.append(f"{indent}> {text}")
|
||||
|
||||
elif block_type == "divider":
|
||||
markdown_lines.append(f"{indent}---")
|
||||
|
||||
elif block_type == "image":
|
||||
image = block["image"]
|
||||
url = image.get("external", {}).get("url") or image.get("file", {}).get(
|
||||
"url", ""
|
||||
)
|
||||
caption = parse_rich_text(image.get("caption", []))
|
||||
alt_text = caption if caption else "Image"
|
||||
markdown_lines.append(f"{indent}")
|
||||
if caption:
|
||||
markdown_lines.append(f"{indent}*{caption}*")
|
||||
|
||||
elif block_type == "video":
|
||||
video = block["video"]
|
||||
url = video.get("external", {}).get("url") or video.get("file", {}).get(
|
||||
"url", ""
|
||||
)
|
||||
caption = parse_rich_text(video.get("caption", []))
|
||||
markdown_lines.append(f"{indent}[Video]({url})")
|
||||
if caption:
|
||||
markdown_lines.append(f"{indent}*{caption}*")
|
||||
|
||||
elif block_type == "file":
|
||||
file = block["file"]
|
||||
url = file.get("external", {}).get("url") or file.get("file", {}).get("url", "")
|
||||
caption = parse_rich_text(file.get("caption", []))
|
||||
name = caption if caption else "File"
|
||||
markdown_lines.append(f"{indent}[{name}]({url})")
|
||||
|
||||
elif block_type == "bookmark":
|
||||
url = block["bookmark"].get("url", "")
|
||||
caption = parse_rich_text(block["bookmark"].get("caption", []))
|
||||
markdown_lines.append(f"{indent}[{caption if caption else url}]({url})")
|
||||
|
||||
elif block_type == "equation":
|
||||
expression = block["equation"].get("expression", "")
|
||||
markdown_lines.append(f"{indent}$${expression}$$")
|
||||
|
||||
elif block_type == "callout":
|
||||
text = rich_text_to_markdown(block["callout"].get("rich_text", []))
|
||||
icon = block["callout"].get("icon", {})
|
||||
if icon.get("emoji"):
|
||||
markdown_lines.append(f"{indent}> {icon['emoji']} {text}")
|
||||
else:
|
||||
markdown_lines.append(f"{indent}> ℹ️ {text}")
|
||||
|
||||
elif block_type == "child_page":
|
||||
title = block["child_page"].get("title", "Untitled")
|
||||
markdown_lines.append(f"{indent}📄 [{title}](notion://page/{block['id']})")
|
||||
|
||||
elif block_type == "child_database":
|
||||
title = block["child_database"].get("title", "Untitled Database")
|
||||
markdown_lines.append(f"{indent}🗂️ [{title}](notion://database/{block['id']})")
|
||||
|
||||
elif block_type == "table":
|
||||
# Tables are complex - for now just indicate there's a table
|
||||
markdown_lines.append(
|
||||
f"{indent}[Table with {block['table'].get('table_width', 0)} columns]"
|
||||
)
|
||||
|
||||
elif block_type == "column_list":
|
||||
# Process columns
|
||||
if block.get("children"):
|
||||
markdown_lines.append(f"{indent}<div style='display: flex'>")
|
||||
for column in block["children"]:
|
||||
markdown_lines.append(f"{indent}<div style='flex: 1'>")
|
||||
if column.get("children"):
|
||||
for child in column["children"]:
|
||||
child_markdown = block_to_markdown(child, indent_level + 1)
|
||||
if child_markdown:
|
||||
markdown_lines.append(child_markdown)
|
||||
markdown_lines.append(f"{indent}</div>")
|
||||
markdown_lines.append(f"{indent}</div>")
|
||||
|
||||
# Handle children for blocks that haven't been processed yet
|
||||
elif block.get("children") and block_type not in ["toggle", "column_list"]:
|
||||
for child in block["children"]:
|
||||
child_markdown = block_to_markdown(child, indent_level)
|
||||
if child_markdown:
|
||||
markdown_lines.append(child_markdown)
|
||||
|
||||
return "\n".join(markdown_lines) if markdown_lines else ""
|
||||
|
||||
|
||||
def blocks_to_markdown(blocks: List[dict]) -> str:
|
||||
"""
|
||||
Convert a list of Notion blocks to a markdown document.
|
||||
|
||||
Args:
|
||||
blocks: List of block objects from Notion API.
|
||||
|
||||
Returns:
|
||||
Complete markdown document as a string.
|
||||
"""
|
||||
markdown_parts = []
|
||||
|
||||
for i, block in enumerate(blocks):
|
||||
markdown = block_to_markdown(block)
|
||||
if markdown:
|
||||
markdown_parts.append(markdown)
|
||||
# Add spacing between top-level blocks (except lists)
|
||||
if i < len(blocks) - 1:
|
||||
next_type = blocks[i + 1].get("type", "")
|
||||
current_type = block.get("type", "")
|
||||
# Don't add extra spacing between list items
|
||||
list_types = {"bulleted_list_item", "numbered_list_item", "to_do"}
|
||||
if not (current_type in list_types and next_type in list_types):
|
||||
markdown_parts.append("")
|
||||
|
||||
return "\n".join(markdown_parts)
|
||||
|
||||
|
||||
def extract_page_title(page: dict) -> str:
|
||||
"""
|
||||
Extract the title from a Notion page object.
|
||||
|
||||
Args:
|
||||
page: Page object from Notion API.
|
||||
|
||||
Returns:
|
||||
Page title as a string.
|
||||
"""
|
||||
properties = page.get("properties", {})
|
||||
|
||||
# Find the title property (it has type "title")
|
||||
for prop_name, prop_value in properties.items():
|
||||
if prop_value.get("type") == "title":
|
||||
return parse_rich_text(prop_value.get("title", []))
|
||||
|
||||
return "Untitled"
|
||||
42
autogpt_platform/backend/backend/blocks/notion/_auth.py
Normal file
42
autogpt_platform/backend/backend/blocks/notion/_auth.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import CredentialsField, CredentialsMetaInput, OAuth2Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.settings import Secrets
|
||||
|
||||
secrets = Secrets()
|
||||
NOTION_OAUTH_IS_CONFIGURED = bool(
|
||||
secrets.notion_client_id and secrets.notion_client_secret
|
||||
)
|
||||
|
||||
NotionCredentials = OAuth2Credentials
|
||||
NotionCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.NOTION], Literal["oauth2"]
|
||||
]
|
||||
|
||||
|
||||
def NotionCredentialsField() -> NotionCredentialsInput:
|
||||
"""Creates a Notion OAuth2 credentials field."""
|
||||
return CredentialsField(
|
||||
description="Connect your Notion account. Ensure the pages/databases are shared with the integration."
|
||||
)
|
||||
|
||||
|
||||
# Test credentials for Notion OAuth2
|
||||
TEST_CREDENTIALS = OAuth2Credentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="notion",
|
||||
access_token=SecretStr("test_access_token"),
|
||||
title="Mock Notion OAuth",
|
||||
scopes=["read_content", "insert_content", "update_content"],
|
||||
username="testuser",
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
360
autogpt_platform/backend/backend/blocks/notion/create_page.py
Normal file
360
autogpt_platform/backend/backend/blocks/notion/create_page.py
Normal file
@@ -0,0 +1,360 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import OAuth2Credentials, SchemaField
|
||||
|
||||
from ._api import NotionClient
|
||||
from ._auth import (
|
||||
NOTION_OAUTH_IS_CONFIGURED,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
NotionCredentialsField,
|
||||
NotionCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class NotionCreatePageBlock(Block):
|
||||
"""Create a new page in Notion with content."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: NotionCredentialsInput = NotionCredentialsField()
|
||||
parent_page_id: Optional[str] = SchemaField(
|
||||
description="Parent page ID to create the page under. Either this OR parent_database_id is required.",
|
||||
default=None,
|
||||
)
|
||||
parent_database_id: Optional[str] = SchemaField(
|
||||
description="Parent database ID to create the page in. Either this OR parent_page_id is required.",
|
||||
default=None,
|
||||
)
|
||||
title: str = SchemaField(
|
||||
description="Title of the new page",
|
||||
)
|
||||
content: Optional[str] = SchemaField(
|
||||
description="Content for the page. Can be plain text or markdown - will be converted to Notion blocks.",
|
||||
default=None,
|
||||
)
|
||||
properties: Optional[Dict[str, Any]] = SchemaField(
|
||||
description="Additional properties for database pages (e.g., {'Status': 'In Progress', 'Priority': 'High'})",
|
||||
default=None,
|
||||
)
|
||||
icon_emoji: Optional[str] = SchemaField(
|
||||
description="Emoji to use as the page icon (e.g., '📄', '🚀')", default=None
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_parent(self):
|
||||
"""Ensure either parent_page_id or parent_database_id is provided."""
|
||||
if not self.parent_page_id and not self.parent_database_id:
|
||||
raise ValueError(
|
||||
"Either parent_page_id or parent_database_id must be provided"
|
||||
)
|
||||
if self.parent_page_id and self.parent_database_id:
|
||||
raise ValueError(
|
||||
"Only one of parent_page_id or parent_database_id should be provided, not both"
|
||||
)
|
||||
return self
|
||||
|
||||
class Output(BlockSchema):
|
||||
page_id: str = SchemaField(description="ID of the created page.")
|
||||
page_url: str = SchemaField(description="URL of the created page.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c15febe0-66ce-4c6f-aebd-5ab351653804",
|
||||
description="Create a new page in Notion. Requires EITHER a parent_page_id OR parent_database_id. Supports markdown content.",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=NotionCreatePageBlock.Input,
|
||||
output_schema=NotionCreatePageBlock.Output,
|
||||
disabled=not NOTION_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"parent_page_id": "00000000-0000-0000-0000-000000000000",
|
||||
"title": "Test Page",
|
||||
"content": "This is test content.",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
("page_id", "12345678-1234-1234-1234-123456789012"),
|
||||
(
|
||||
"page_url",
|
||||
"https://notion.so/Test-Page-12345678123412341234123456789012",
|
||||
),
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={
|
||||
"create_page": lambda *args, **kwargs: (
|
||||
"12345678-1234-1234-1234-123456789012",
|
||||
"https://notion.so/Test-Page-12345678123412341234123456789012",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _markdown_to_blocks(content: str) -> List[dict]:
|
||||
"""Convert markdown content to Notion block objects."""
|
||||
if not content:
|
||||
return []
|
||||
|
||||
blocks = []
|
||||
lines = content.split("\n")
|
||||
i = 0
|
||||
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
|
||||
# Skip empty lines
|
||||
if not line.strip():
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Headings
|
||||
if line.startswith("### "):
|
||||
blocks.append(
|
||||
{
|
||||
"type": "heading_3",
|
||||
"heading_3": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": line[4:].strip()}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
elif line.startswith("## "):
|
||||
blocks.append(
|
||||
{
|
||||
"type": "heading_2",
|
||||
"heading_2": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": line[3:].strip()}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
elif line.startswith("# "):
|
||||
blocks.append(
|
||||
{
|
||||
"type": "heading_1",
|
||||
"heading_1": {
|
||||
"rich_text": [
|
||||
{"type": "text", "text": {"content": line[2:].strip()}}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
# Bullet points
|
||||
elif line.strip().startswith("- "):
|
||||
blocks.append(
|
||||
{
|
||||
"type": "bulleted_list_item",
|
||||
"bulleted_list_item": {
|
||||
"rich_text": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": {"content": line.strip()[2:].strip()},
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
# Numbered list
|
||||
elif line.strip() and line.strip()[0].isdigit() and ". " in line:
|
||||
content_start = line.find(". ") + 2
|
||||
blocks.append(
|
||||
{
|
||||
"type": "numbered_list_item",
|
||||
"numbered_list_item": {
|
||||
"rich_text": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": {"content": line[content_start:].strip()},
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
# Code block
|
||||
elif line.strip().startswith("```"):
|
||||
code_lines = []
|
||||
language = line[3:].strip() or "plain text"
|
||||
i += 1
|
||||
while i < len(lines) and not lines[i].strip().startswith("```"):
|
||||
code_lines.append(lines[i])
|
||||
i += 1
|
||||
blocks.append(
|
||||
{
|
||||
"type": "code",
|
||||
"code": {
|
||||
"rich_text": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": {"content": "\n".join(code_lines)},
|
||||
}
|
||||
],
|
||||
"language": language,
|
||||
},
|
||||
}
|
||||
)
|
||||
# Quote
|
||||
elif line.strip().startswith("> "):
|
||||
blocks.append(
|
||||
{
|
||||
"type": "quote",
|
||||
"quote": {
|
||||
"rich_text": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": {"content": line.strip()[2:].strip()},
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
# Horizontal rule
|
||||
elif line.strip() in ["---", "***", "___"]:
|
||||
blocks.append({"type": "divider", "divider": {}})
|
||||
# Regular paragraph
|
||||
else:
|
||||
# Parse for basic markdown formatting
|
||||
text_content = line.strip()
|
||||
rich_text = []
|
||||
|
||||
# Simple bold/italic parsing (this is simplified)
|
||||
if "**" in text_content or "*" in text_content:
|
||||
# For now, just pass as plain text
|
||||
# A full implementation would parse and create proper annotations
|
||||
rich_text = [{"type": "text", "text": {"content": text_content}}]
|
||||
else:
|
||||
rich_text = [{"type": "text", "text": {"content": text_content}}]
|
||||
|
||||
blocks.append(
|
||||
{"type": "paragraph", "paragraph": {"rich_text": rich_text}}
|
||||
)
|
||||
|
||||
i += 1
|
||||
|
||||
return blocks
|
||||
|
||||
@staticmethod
|
||||
def _build_properties(
|
||||
title: str, additional_properties: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Build properties object for page creation."""
|
||||
properties: Dict[str, Any] = {
|
||||
"title": {"title": [{"type": "text", "text": {"content": title}}]}
|
||||
}
|
||||
|
||||
if additional_properties:
|
||||
for key, value in additional_properties.items():
|
||||
if key.lower() == "title":
|
||||
continue # Skip title as we already have it
|
||||
|
||||
# Try to intelligently map property types
|
||||
if isinstance(value, bool):
|
||||
properties[key] = {"checkbox": value}
|
||||
elif isinstance(value, (int, float)):
|
||||
properties[key] = {"number": value}
|
||||
elif isinstance(value, list):
|
||||
# Assume multi-select
|
||||
properties[key] = {
|
||||
"multi_select": [{"name": str(item)} for item in value]
|
||||
}
|
||||
elif isinstance(value, str):
|
||||
# Could be select, rich_text, or other types
|
||||
# For simplicity, try common patterns
|
||||
if key.lower() in ["status", "priority", "type", "category"]:
|
||||
properties[key] = {"select": {"name": value}}
|
||||
elif key.lower() in ["url", "link"]:
|
||||
properties[key] = {"url": value}
|
||||
elif key.lower() in ["email"]:
|
||||
properties[key] = {"email": value}
|
||||
else:
|
||||
properties[key] = {
|
||||
"rich_text": [{"type": "text", "text": {"content": value}}]
|
||||
}
|
||||
|
||||
return properties
|
||||
|
||||
@staticmethod
|
||||
async def create_page(
|
||||
credentials: OAuth2Credentials,
|
||||
title: str,
|
||||
parent_page_id: Optional[str] = None,
|
||||
parent_database_id: Optional[str] = None,
|
||||
content: Optional[str] = None,
|
||||
properties: Optional[Dict[str, Any]] = None,
|
||||
icon_emoji: Optional[str] = None,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Create a new Notion page.
|
||||
|
||||
Returns:
|
||||
Tuple of (page_id, page_url)
|
||||
"""
|
||||
if not parent_page_id and not parent_database_id:
|
||||
raise ValueError(
|
||||
"Either parent_page_id or parent_database_id must be provided"
|
||||
)
|
||||
if parent_page_id and parent_database_id:
|
||||
raise ValueError(
|
||||
"Only one of parent_page_id or parent_database_id should be provided, not both"
|
||||
)
|
||||
|
||||
client = NotionClient(credentials)
|
||||
|
||||
# Build parent object
|
||||
if parent_page_id:
|
||||
parent = {"type": "page_id", "page_id": parent_page_id}
|
||||
else:
|
||||
parent = {"type": "database_id", "database_id": parent_database_id}
|
||||
|
||||
# Build properties
|
||||
page_properties = NotionCreatePageBlock._build_properties(title, properties)
|
||||
|
||||
# Convert content to blocks if provided
|
||||
children = None
|
||||
if content:
|
||||
children = NotionCreatePageBlock._markdown_to_blocks(content)
|
||||
|
||||
# Build icon if provided
|
||||
icon = None
|
||||
if icon_emoji:
|
||||
icon = {"type": "emoji", "emoji": icon_emoji}
|
||||
|
||||
# Create the page
|
||||
result = await client.create_page(
|
||||
parent=parent, properties=page_properties, children=children, icon=icon
|
||||
)
|
||||
|
||||
page_id = result.get("id", "")
|
||||
page_url = result.get("url", "")
|
||||
|
||||
if not page_id or not page_url:
|
||||
raise ValueError("Failed to get page ID or URL from Notion response")
|
||||
|
||||
return page_id, page_url
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: OAuth2Credentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
page_id, page_url = await self.create_page(
|
||||
credentials,
|
||||
input_data.title,
|
||||
input_data.parent_page_id,
|
||||
input_data.parent_database_id,
|
||||
input_data.content,
|
||||
input_data.properties,
|
||||
input_data.icon_emoji,
|
||||
)
|
||||
yield "page_id", page_id
|
||||
yield "page_url", page_url
|
||||
except Exception as e:
|
||||
yield "error", str(e) if str(e) else "Unknown error"
|
||||
285
autogpt_platform/backend/backend/blocks/notion/read_database.py
Normal file
285
autogpt_platform/backend/backend/blocks/notion/read_database.py
Normal file
@@ -0,0 +1,285 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import OAuth2Credentials, SchemaField
|
||||
|
||||
from ._api import NotionClient, parse_rich_text
|
||||
from ._auth import (
|
||||
NOTION_OAUTH_IS_CONFIGURED,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
NotionCredentialsField,
|
||||
NotionCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class NotionReadDatabaseBlock(Block):
|
||||
"""Query a Notion database and retrieve entries with their properties."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: NotionCredentialsInput = NotionCredentialsField()
|
||||
database_id: str = SchemaField(
|
||||
description="Notion database ID. Must be accessible by the connected integration.",
|
||||
)
|
||||
filter_property: Optional[str] = SchemaField(
|
||||
description="Property name to filter by (e.g., 'Status', 'Priority')",
|
||||
default=None,
|
||||
)
|
||||
filter_value: Optional[str] = SchemaField(
|
||||
description="Value to filter for in the specified property", default=None
|
||||
)
|
||||
sort_property: Optional[str] = SchemaField(
|
||||
description="Property name to sort by", default=None
|
||||
)
|
||||
sort_direction: Optional[str] = SchemaField(
|
||||
description="Sort direction: 'ascending' or 'descending'",
|
||||
default="ascending",
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of entries to retrieve",
|
||||
default=100,
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
entries: List[Dict[str, Any]] = SchemaField(
|
||||
description="List of database entries with their properties."
|
||||
)
|
||||
entry: Dict[str, Any] = SchemaField(
|
||||
description="Individual database entry (yields one per entry found)."
|
||||
)
|
||||
entry_ids: List[str] = SchemaField(
|
||||
description="List of entry IDs for batch operations."
|
||||
)
|
||||
entry_id: str = SchemaField(
|
||||
description="Individual entry ID (yields one per entry found)."
|
||||
)
|
||||
count: int = SchemaField(description="Number of entries retrieved.")
|
||||
database_title: str = SchemaField(description="Title of the database.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="fcd53135-88c9-4ba3-be50-cc6936286e6c",
|
||||
description="Query a Notion database with optional filtering and sorting, returning structured entries.",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=NotionReadDatabaseBlock.Input,
|
||||
output_schema=NotionReadDatabaseBlock.Output,
|
||||
disabled=not NOTION_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"database_id": "00000000-0000-0000-0000-000000000000",
|
||||
"limit": 10,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"entries",
|
||||
[{"Name": "Test Entry", "Status": "Active", "_id": "test-123"}],
|
||||
),
|
||||
("entry_ids", ["test-123"]),
|
||||
(
|
||||
"entry",
|
||||
{"Name": "Test Entry", "Status": "Active", "_id": "test-123"},
|
||||
),
|
||||
("entry_id", "test-123"),
|
||||
("count", 1),
|
||||
("database_title", "Test Database"),
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={
|
||||
"query_database": lambda *args, **kwargs: (
|
||||
[{"Name": "Test Entry", "Status": "Active", "_id": "test-123"}],
|
||||
1,
|
||||
"Test Database",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_property_value(prop: dict) -> Any:
|
||||
"""Parse a Notion property value into a simple Python type."""
|
||||
prop_type = prop.get("type")
|
||||
|
||||
if prop_type == "title":
|
||||
return parse_rich_text(prop.get("title", []))
|
||||
elif prop_type == "rich_text":
|
||||
return parse_rich_text(prop.get("rich_text", []))
|
||||
elif prop_type == "number":
|
||||
return prop.get("number")
|
||||
elif prop_type == "select":
|
||||
select = prop.get("select")
|
||||
return select.get("name") if select else None
|
||||
elif prop_type == "multi_select":
|
||||
return [item.get("name") for item in prop.get("multi_select", [])]
|
||||
elif prop_type == "date":
|
||||
date = prop.get("date")
|
||||
if date:
|
||||
return date.get("start")
|
||||
return None
|
||||
elif prop_type == "checkbox":
|
||||
return prop.get("checkbox", False)
|
||||
elif prop_type == "url":
|
||||
return prop.get("url")
|
||||
elif prop_type == "email":
|
||||
return prop.get("email")
|
||||
elif prop_type == "phone_number":
|
||||
return prop.get("phone_number")
|
||||
elif prop_type == "people":
|
||||
return [
|
||||
person.get("name", person.get("id"))
|
||||
for person in prop.get("people", [])
|
||||
]
|
||||
elif prop_type == "files":
|
||||
files = prop.get("files", [])
|
||||
return [
|
||||
f.get(
|
||||
"name",
|
||||
f.get("external", {}).get("url", f.get("file", {}).get("url")),
|
||||
)
|
||||
for f in files
|
||||
]
|
||||
elif prop_type == "relation":
|
||||
return [rel.get("id") for rel in prop.get("relation", [])]
|
||||
elif prop_type == "formula":
|
||||
formula = prop.get("formula", {})
|
||||
return formula.get(formula.get("type"))
|
||||
elif prop_type == "rollup":
|
||||
rollup = prop.get("rollup", {})
|
||||
return rollup.get(rollup.get("type"))
|
||||
elif prop_type == "created_time":
|
||||
return prop.get("created_time")
|
||||
elif prop_type == "created_by":
|
||||
return prop.get("created_by", {}).get(
|
||||
"name", prop.get("created_by", {}).get("id")
|
||||
)
|
||||
elif prop_type == "last_edited_time":
|
||||
return prop.get("last_edited_time")
|
||||
elif prop_type == "last_edited_by":
|
||||
return prop.get("last_edited_by", {}).get(
|
||||
"name", prop.get("last_edited_by", {}).get("id")
|
||||
)
|
||||
else:
|
||||
# Return the raw value for unknown types
|
||||
return prop
|
||||
|
||||
@staticmethod
|
||||
def _build_filter(property_name: str, value: str) -> dict:
|
||||
"""Build a simple filter object for a property."""
|
||||
# This is a simplified filter - in reality, you'd need to know the property type
|
||||
# For now, we'll try common filter types
|
||||
return {
|
||||
"or": [
|
||||
{"property": property_name, "rich_text": {"contains": value}},
|
||||
{"property": property_name, "title": {"contains": value}},
|
||||
{"property": property_name, "select": {"equals": value}},
|
||||
{"property": property_name, "multi_select": {"contains": value}},
|
||||
]
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def query_database(
|
||||
credentials: OAuth2Credentials,
|
||||
database_id: str,
|
||||
filter_property: Optional[str] = None,
|
||||
filter_value: Optional[str] = None,
|
||||
sort_property: Optional[str] = None,
|
||||
sort_direction: str = "ascending",
|
||||
limit: int = 100,
|
||||
) -> tuple[List[Dict[str, Any]], int, str]:
|
||||
"""
|
||||
Query a Notion database and parse the results.
|
||||
|
||||
Returns:
|
||||
Tuple of (entries_list, count, database_title)
|
||||
"""
|
||||
client = NotionClient(credentials)
|
||||
|
||||
# Build filter if specified
|
||||
filter_obj = None
|
||||
if filter_property and filter_value:
|
||||
filter_obj = NotionReadDatabaseBlock._build_filter(
|
||||
filter_property, filter_value
|
||||
)
|
||||
|
||||
# Build sorts if specified
|
||||
sorts = None
|
||||
if sort_property:
|
||||
sorts = [{"property": sort_property, "direction": sort_direction}]
|
||||
|
||||
# Query the database
|
||||
result = await client.query_database(
|
||||
database_id, filter_obj=filter_obj, sorts=sorts, page_size=limit
|
||||
)
|
||||
|
||||
# Parse the entries
|
||||
entries = []
|
||||
for page in result.get("results", []):
|
||||
entry = {}
|
||||
properties = page.get("properties", {})
|
||||
|
||||
for prop_name, prop_value in properties.items():
|
||||
entry[prop_name] = NotionReadDatabaseBlock._parse_property_value(
|
||||
prop_value
|
||||
)
|
||||
|
||||
# Add metadata
|
||||
entry["_id"] = page.get("id")
|
||||
entry["_url"] = page.get("url")
|
||||
entry["_created_time"] = page.get("created_time")
|
||||
entry["_last_edited_time"] = page.get("last_edited_time")
|
||||
|
||||
entries.append(entry)
|
||||
|
||||
# Get database title (we need to make a separate call for this)
|
||||
try:
|
||||
database_url = f"https://api.notion.com/v1/databases/{database_id}"
|
||||
db_response = await client.requests.get(
|
||||
database_url, headers=client.headers
|
||||
)
|
||||
if db_response.ok:
|
||||
db_data = db_response.json()
|
||||
db_title = parse_rich_text(db_data.get("title", []))
|
||||
else:
|
||||
db_title = "Unknown Database"
|
||||
except Exception:
|
||||
db_title = "Unknown Database"
|
||||
|
||||
return entries, len(entries), db_title
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: OAuth2Credentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
entries, count, db_title = await self.query_database(
|
||||
credentials,
|
||||
input_data.database_id,
|
||||
input_data.filter_property,
|
||||
input_data.filter_value,
|
||||
input_data.sort_property,
|
||||
input_data.sort_direction or "ascending",
|
||||
input_data.limit,
|
||||
)
|
||||
# Yield the complete list for batch operations
|
||||
yield "entries", entries
|
||||
|
||||
# Extract and yield IDs as a list for batch operations
|
||||
entry_ids = [entry["_id"] for entry in entries if "_id" in entry]
|
||||
yield "entry_ids", entry_ids
|
||||
|
||||
# Yield each individual entry and its ID for single connections
|
||||
for entry in entries:
|
||||
yield "entry", entry
|
||||
if "_id" in entry:
|
||||
yield "entry_id", entry["_id"]
|
||||
|
||||
yield "count", count
|
||||
yield "database_title", db_title
|
||||
except Exception as e:
|
||||
yield "error", str(e) if str(e) else "Unknown error"
|
||||
64
autogpt_platform/backend/backend/blocks/notion/read_page.py
Normal file
64
autogpt_platform/backend/backend/blocks/notion/read_page.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import OAuth2Credentials, SchemaField
|
||||
|
||||
from ._api import NotionClient
|
||||
from ._auth import (
|
||||
NOTION_OAUTH_IS_CONFIGURED,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
NotionCredentialsField,
|
||||
NotionCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class NotionReadPageBlock(Block):
|
||||
"""Read a Notion page by ID and return its raw JSON."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: NotionCredentialsInput = NotionCredentialsField()
|
||||
page_id: str = SchemaField(
|
||||
description="Notion page ID. Must be accessible by the connected integration. You can get this from the page URL notion.so/A-Page-586edd711467478da59fe3ce29a1ffab would be 586edd711467478da59fe35e29a1ffab",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
page: dict = SchemaField(description="Raw Notion page JSON.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="5246cc1d-34b7-452b-8fc5-3fb25fd8f542",
|
||||
description="Read a Notion page by its ID and return its raw JSON.",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=NotionReadPageBlock.Input,
|
||||
output_schema=NotionReadPageBlock.Output,
|
||||
disabled=not NOTION_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"page_id": "00000000-0000-0000-0000-000000000000",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[("page", dict)],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={
|
||||
"get_page": lambda *args, **kwargs: {"object": "page", "id": "mocked"}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_page(credentials: OAuth2Credentials, page_id: str) -> dict:
|
||||
client = NotionClient(credentials)
|
||||
return await client.get_page(page_id)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: OAuth2Credentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
page = await self.get_page(credentials, input_data.page_id)
|
||||
yield "page", page
|
||||
except Exception as e:
|
||||
yield "error", str(e) if str(e) else "Unknown error"
|
||||
@@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import OAuth2Credentials, SchemaField
|
||||
|
||||
from ._api import NotionClient, blocks_to_markdown, extract_page_title
|
||||
from ._auth import (
|
||||
NOTION_OAUTH_IS_CONFIGURED,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
NotionCredentialsField,
|
||||
NotionCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class NotionReadPageMarkdownBlock(Block):
|
||||
"""Read a Notion page and convert it to clean Markdown format."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: NotionCredentialsInput = NotionCredentialsField()
|
||||
page_id: str = SchemaField(
|
||||
description="Notion page ID. Must be accessible by the connected integration. You can get this from the page URL notion.so/A-Page-586edd711467478da59fe35e29a1ffab would be 586edd711467478da59fe35e29a1ffab",
|
||||
)
|
||||
include_title: bool = SchemaField(
|
||||
description="Whether to include the page title as a header in the markdown",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
markdown: str = SchemaField(description="Page content in Markdown format.")
|
||||
title: str = SchemaField(description="Page title.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d1312c4d-fae2-4e70-893d-f4d07cce1d4e",
|
||||
description="Read a Notion page and convert it to Markdown format with proper formatting for headings, lists, links, and rich text.",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=NotionReadPageMarkdownBlock.Input,
|
||||
output_schema=NotionReadPageMarkdownBlock.Output,
|
||||
disabled=not NOTION_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"page_id": "00000000-0000-0000-0000-000000000000",
|
||||
"include_title": True,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
("markdown", "# Test Page\n\nThis is test content."),
|
||||
("title", "Test Page"),
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={
|
||||
"get_page_markdown": lambda *args, **kwargs: (
|
||||
"# Test Page\n\nThis is test content.",
|
||||
"Test Page",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_page_markdown(
|
||||
credentials: OAuth2Credentials, page_id: str, include_title: bool = True
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Get a Notion page and convert it to markdown.
|
||||
|
||||
Args:
|
||||
credentials: OAuth2 credentials for Notion.
|
||||
page_id: The ID of the page to fetch.
|
||||
include_title: Whether to include the page title in the markdown.
|
||||
|
||||
Returns:
|
||||
Tuple of (markdown_content, title)
|
||||
"""
|
||||
client = NotionClient(credentials)
|
||||
|
||||
# Get page metadata
|
||||
page = await client.get_page(page_id)
|
||||
title = extract_page_title(page)
|
||||
|
||||
# Get all blocks from the page
|
||||
blocks = await client.get_blocks(page_id, recursive=True)
|
||||
|
||||
# Convert blocks to markdown
|
||||
content_markdown = blocks_to_markdown(blocks)
|
||||
|
||||
# Combine title and content if requested
|
||||
if include_title and title:
|
||||
full_markdown = f"# {title}\n\n{content_markdown}"
|
||||
else:
|
||||
full_markdown = content_markdown
|
||||
|
||||
return full_markdown, title
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: OAuth2Credentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
markdown, title = await self.get_page_markdown(
|
||||
credentials, input_data.page_id, input_data.include_title
|
||||
)
|
||||
yield "markdown", markdown
|
||||
yield "title", title
|
||||
except Exception as e:
|
||||
yield "error", str(e) if str(e) else "Unknown error"
|
||||
225
autogpt_platform/backend/backend/blocks/notion/search.py
Normal file
225
autogpt_platform/backend/backend/blocks/notion/search.py
Normal file
@@ -0,0 +1,225 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import OAuth2Credentials, SchemaField
|
||||
|
||||
from ._api import NotionClient, extract_page_title, parse_rich_text
|
||||
from ._auth import (
|
||||
NOTION_OAUTH_IS_CONFIGURED,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
NotionCredentialsField,
|
||||
NotionCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class NotionSearchResult(BaseModel):
|
||||
"""Typed model for Notion search results."""
|
||||
|
||||
id: str
|
||||
type: str # 'page' or 'database'
|
||||
title: str
|
||||
url: str
|
||||
created_time: Optional[str] = None
|
||||
last_edited_time: Optional[str] = None
|
||||
parent_type: Optional[str] = None # 'page', 'database', or 'workspace'
|
||||
parent_id: Optional[str] = None
|
||||
icon: Optional[str] = None # emoji icon if present
|
||||
is_inline: Optional[bool] = None # for databases only
|
||||
|
||||
|
||||
class NotionSearchBlock(Block):
|
||||
"""Search across your Notion workspace for pages and databases."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: NotionCredentialsInput = NotionCredentialsField()
|
||||
query: str = SchemaField(
|
||||
description="Search query text. Leave empty to get all accessible pages/databases.",
|
||||
default="",
|
||||
)
|
||||
filter_type: Optional[str] = SchemaField(
|
||||
description="Filter results by type: 'page' or 'database'. Leave empty for both.",
|
||||
default=None,
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
description="Maximum number of results to return", default=20, ge=1, le=100
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
results: List[NotionSearchResult] = SchemaField(
|
||||
description="List of search results with title, type, URL, and metadata."
|
||||
)
|
||||
result: NotionSearchResult = SchemaField(
|
||||
description="Individual search result (yields one per result found)."
|
||||
)
|
||||
result_ids: List[str] = SchemaField(
|
||||
description="List of IDs from search results for batch operations."
|
||||
)
|
||||
count: int = SchemaField(description="Number of results found.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="313515dd-9848-46ea-9cd6-3c627c892c56",
|
||||
description="Search your Notion workspace for pages and databases by text query.",
|
||||
categories={BlockCategory.PRODUCTIVITY, BlockCategory.SEARCH},
|
||||
input_schema=NotionSearchBlock.Input,
|
||||
output_schema=NotionSearchBlock.Output,
|
||||
disabled=not NOTION_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"query": "project",
|
||||
"limit": 5,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"results",
|
||||
[
|
||||
NotionSearchResult(
|
||||
id="123",
|
||||
type="page",
|
||||
title="Project Plan",
|
||||
url="https://notion.so/Project-Plan-123",
|
||||
)
|
||||
],
|
||||
),
|
||||
("result_ids", ["123"]),
|
||||
(
|
||||
"result",
|
||||
NotionSearchResult(
|
||||
id="123",
|
||||
type="page",
|
||||
title="Project Plan",
|
||||
url="https://notion.so/Project-Plan-123",
|
||||
),
|
||||
),
|
||||
("count", 1),
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={
|
||||
"search_workspace": lambda *args, **kwargs: (
|
||||
[
|
||||
NotionSearchResult(
|
||||
id="123",
|
||||
type="page",
|
||||
title="Project Plan",
|
||||
url="https://notion.so/Project-Plan-123",
|
||||
)
|
||||
],
|
||||
1,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def search_workspace(
|
||||
credentials: OAuth2Credentials,
|
||||
query: str = "",
|
||||
filter_type: Optional[str] = None,
|
||||
limit: int = 20,
|
||||
) -> tuple[List[NotionSearchResult], int]:
|
||||
"""
|
||||
Search the Notion workspace.
|
||||
|
||||
Returns:
|
||||
Tuple of (results_list, count)
|
||||
"""
|
||||
client = NotionClient(credentials)
|
||||
|
||||
# Build filter if type is specified
|
||||
filter_obj = None
|
||||
if filter_type:
|
||||
filter_obj = {"property": "object", "value": filter_type}
|
||||
|
||||
# Execute search
|
||||
response = await client.search(
|
||||
query=query, filter_obj=filter_obj, page_size=limit
|
||||
)
|
||||
|
||||
# Parse results
|
||||
results = []
|
||||
for item in response.get("results", []):
|
||||
result_data = {
|
||||
"id": item.get("id", ""),
|
||||
"type": item.get("object", ""),
|
||||
"url": item.get("url", ""),
|
||||
"created_time": item.get("created_time"),
|
||||
"last_edited_time": item.get("last_edited_time"),
|
||||
"title": "", # Will be set below
|
||||
}
|
||||
|
||||
# Extract title based on type
|
||||
if item.get("object") == "page":
|
||||
# For pages, get the title from properties
|
||||
result_data["title"] = extract_page_title(item)
|
||||
|
||||
# Add parent info
|
||||
parent = item.get("parent", {})
|
||||
if parent.get("type") == "page_id":
|
||||
result_data["parent_type"] = "page"
|
||||
result_data["parent_id"] = parent.get("page_id")
|
||||
elif parent.get("type") == "database_id":
|
||||
result_data["parent_type"] = "database"
|
||||
result_data["parent_id"] = parent.get("database_id")
|
||||
elif parent.get("type") == "workspace":
|
||||
result_data["parent_type"] = "workspace"
|
||||
|
||||
# Add icon if present
|
||||
icon = item.get("icon")
|
||||
if icon and icon.get("type") == "emoji":
|
||||
result_data["icon"] = icon.get("emoji")
|
||||
|
||||
elif item.get("object") == "database":
|
||||
# For databases, get title from the title array
|
||||
result_data["title"] = parse_rich_text(item.get("title", []))
|
||||
|
||||
# Add database-specific metadata
|
||||
result_data["is_inline"] = item.get("is_inline", False)
|
||||
|
||||
# Add parent info
|
||||
parent = item.get("parent", {})
|
||||
if parent.get("type") == "page_id":
|
||||
result_data["parent_type"] = "page"
|
||||
result_data["parent_id"] = parent.get("page_id")
|
||||
elif parent.get("type") == "workspace":
|
||||
result_data["parent_type"] = "workspace"
|
||||
|
||||
# Add icon if present
|
||||
icon = item.get("icon")
|
||||
if icon and icon.get("type") == "emoji":
|
||||
result_data["icon"] = icon.get("emoji")
|
||||
|
||||
results.append(NotionSearchResult(**result_data))
|
||||
|
||||
return results, len(results)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: OAuth2Credentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
results, count = await self.search_workspace(
|
||||
credentials, input_data.query, input_data.filter_type, input_data.limit
|
||||
)
|
||||
|
||||
# Yield the complete list for batch operations
|
||||
yield "results", results
|
||||
|
||||
# Extract and yield IDs as a list for batch operations
|
||||
result_ids = [r.id for r in results]
|
||||
yield "result_ids", result_ids
|
||||
|
||||
# Yield each individual result for single connections
|
||||
for result in results:
|
||||
yield "result", result
|
||||
|
||||
yield "count", count
|
||||
except Exception as e:
|
||||
yield "error", str(e) if str(e) else "Unknown error"
|
||||
@@ -519,35 +519,121 @@ class SmartDecisionMakerBlock(Block):
|
||||
):
|
||||
prompt.append({"role": "user", "content": prefix + input_data.prompt})
|
||||
|
||||
response = await llm.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
json_format=False,
|
||||
max_tokens=input_data.max_tokens,
|
||||
tools=tool_functions,
|
||||
ollama_host=input_data.ollama_host,
|
||||
parallel_tool_calls=input_data.multiple_tool_calls,
|
||||
# Use retry decorator for LLM calls with validation
|
||||
from backend.util.retry import create_retry_decorator
|
||||
|
||||
# Create retry decorator that excludes ValueError from retry (for non-LLM errors)
|
||||
llm_retry = create_retry_decorator(
|
||||
max_attempts=input_data.retry,
|
||||
exclude_exceptions=(), # Don't exclude ValueError - we want to retry validation failures
|
||||
context="SmartDecisionMaker LLM call",
|
||||
)
|
||||
|
||||
# Track LLM usage stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
llm_call_count=1,
|
||||
@llm_retry
|
||||
async def call_llm_with_validation():
|
||||
response = await llm.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
max_tokens=input_data.max_tokens,
|
||||
tools=tool_functions,
|
||||
ollama_host=input_data.ollama_host,
|
||||
parallel_tool_calls=input_data.multiple_tool_calls,
|
||||
)
|
||||
)
|
||||
|
||||
# Track LLM usage stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
llm_call_count=1,
|
||||
)
|
||||
)
|
||||
|
||||
if not response.tool_calls:
|
||||
return response, None # No tool calls, return response
|
||||
|
||||
# Validate all tool calls before proceeding
|
||||
validation_errors = []
|
||||
for tool_call in response.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
|
||||
# Find the tool definition to get the expected arguments
|
||||
tool_def = next(
|
||||
(
|
||||
tool
|
||||
for tool in tool_functions
|
||||
if tool["function"]["name"] == tool_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# Get parameters schema from tool definition
|
||||
if (
|
||||
tool_def
|
||||
and "function" in tool_def
|
||||
and "parameters" in tool_def["function"]
|
||||
):
|
||||
parameters = tool_def["function"]["parameters"]
|
||||
expected_args = parameters.get("properties", {})
|
||||
required_params = set(parameters.get("required", []))
|
||||
else:
|
||||
expected_args = {arg: {} for arg in tool_args.keys()}
|
||||
required_params = set()
|
||||
|
||||
# Validate tool call arguments
|
||||
provided_args = set(tool_args.keys())
|
||||
expected_args_set = set(expected_args.keys())
|
||||
|
||||
# Check for unexpected arguments (typos)
|
||||
unexpected_args = provided_args - expected_args_set
|
||||
# Only check for missing REQUIRED parameters
|
||||
missing_required_args = required_params - provided_args
|
||||
|
||||
if unexpected_args or missing_required_args:
|
||||
error_msg = f"Tool call '{tool_name}' has parameter errors:"
|
||||
if unexpected_args:
|
||||
error_msg += f" Unknown parameters: {sorted(unexpected_args)}."
|
||||
if missing_required_args:
|
||||
error_msg += f" Missing required parameters: {sorted(missing_required_args)}."
|
||||
error_msg += f" Expected parameters: {sorted(expected_args_set)}."
|
||||
if required_params:
|
||||
error_msg += f" Required parameters: {sorted(required_params)}."
|
||||
validation_errors.append(error_msg)
|
||||
|
||||
# If validation failed, add feedback and raise for retry
|
||||
if validation_errors:
|
||||
# Add the failed response to conversation
|
||||
prompt.append(response.raw_response)
|
||||
|
||||
# Add error feedback for retry
|
||||
error_feedback = (
|
||||
"Your tool call had parameter errors. Please fix the following issues and try again:\n"
|
||||
+ "\n".join(f"- {error}" for error in validation_errors)
|
||||
+ "\n\nPlease make sure to use the exact parameter names as specified in the function schema."
|
||||
)
|
||||
prompt.append({"role": "user", "content": error_feedback})
|
||||
|
||||
raise ValueError(
|
||||
f"Tool call validation failed: {'; '.join(validation_errors)}"
|
||||
)
|
||||
|
||||
return response, validation_errors
|
||||
|
||||
# Call the LLM with retry logic
|
||||
response, validation_errors = await call_llm_with_validation()
|
||||
|
||||
if not response.tool_calls:
|
||||
yield "finished", response.response
|
||||
return
|
||||
|
||||
# If we get here, validation passed - yield tool outputs
|
||||
for tool_call in response.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
|
||||
# Find the tool definition to get the expected arguments
|
||||
# Get expected arguments (already validated above)
|
||||
tool_def = next(
|
||||
(
|
||||
tool
|
||||
@@ -556,7 +642,6 @@ class SmartDecisionMakerBlock(Block):
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if (
|
||||
tool_def
|
||||
and "function" in tool_def
|
||||
@@ -564,14 +649,11 @@ class SmartDecisionMakerBlock(Block):
|
||||
):
|
||||
expected_args = tool_def["function"]["parameters"].get("properties", {})
|
||||
else:
|
||||
expected_args = tool_args.keys()
|
||||
expected_args = {arg: {} for arg in tool_args.keys()}
|
||||
|
||||
# Yield provided arguments and None for missing ones
|
||||
# Yield provided arguments, use .get() for optional parameters
|
||||
for arg_name in expected_args:
|
||||
if arg_name in tool_args:
|
||||
yield f"tools_^_{tool_name}_~_{arg_name}", tool_args[arg_name]
|
||||
else:
|
||||
yield f"tools_^_{tool_name}_~_{arg_name}", None
|
||||
yield f"tools_^_{tool_name}_~_{arg_name}", tool_args.get(arg_name)
|
||||
|
||||
# Add reasoning to conversation history if available
|
||||
if response.reasoning:
|
||||
|
||||
@@ -30,7 +30,6 @@ class TestLLMStatsTracking:
|
||||
credentials=llm.TEST_CREDENTIALS,
|
||||
llm_model=llm.LlmModel.GPT4O,
|
||||
prompt=[{"role": "user", "content": "Hello"}],
|
||||
json_format=False,
|
||||
max_tokens=100,
|
||||
)
|
||||
|
||||
@@ -42,6 +41,8 @@ class TestLLMStatsTracking:
|
||||
@pytest.mark.asyncio
|
||||
async def test_ai_structured_response_block_tracks_stats(self):
|
||||
"""Test that AIStructuredResponseGeneratorBlock correctly tracks stats."""
|
||||
from unittest.mock import patch
|
||||
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
@@ -51,7 +52,7 @@ class TestLLMStatsTracking:
|
||||
return llm.LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response='{"key1": "value1", "key2": "value2"}',
|
||||
response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
|
||||
tool_calls=None,
|
||||
prompt_tokens=15,
|
||||
completion_tokens=25,
|
||||
@@ -69,10 +70,12 @@ class TestLLMStatsTracking:
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
# Mock secrets.token_hex to return consistent ID
|
||||
with patch("secrets.token_hex", return_value="test123456"):
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Check stats
|
||||
assert block.execution_stats.input_token_count == 15
|
||||
@@ -143,7 +146,7 @@ class TestLLMStatsTracking:
|
||||
return llm.LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response='{"wrong": "format"}',
|
||||
response='<json_output id="test123456">{"wrong": "format"}</json_output>',
|
||||
tool_calls=None,
|
||||
prompt_tokens=10,
|
||||
completion_tokens=15,
|
||||
@@ -154,7 +157,7 @@ class TestLLMStatsTracking:
|
||||
return llm.LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response='{"key1": "value1", "key2": "value2"}',
|
||||
response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
|
||||
tool_calls=None,
|
||||
prompt_tokens=20,
|
||||
completion_tokens=25,
|
||||
@@ -173,10 +176,12 @@ class TestLLMStatsTracking:
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
# Mock secrets.token_hex to return consistent ID
|
||||
with patch("secrets.token_hex", return_value="test123456"):
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Check stats - should accumulate both calls
|
||||
# For 2 attempts: attempt 1 (failed) + attempt 2 (success) = 2 total
|
||||
@@ -269,7 +274,8 @@ class TestLLMStatsTracking:
|
||||
mock_response.choices = [
|
||||
MagicMock(
|
||||
message=MagicMock(
|
||||
content='{"summary": "Test chunk summary"}', tool_calls=None
|
||||
content='<json_output id="test123456">{"summary": "Test chunk summary"}</json_output>',
|
||||
tool_calls=None,
|
||||
)
|
||||
)
|
||||
]
|
||||
@@ -277,7 +283,7 @@ class TestLLMStatsTracking:
|
||||
mock_response.choices = [
|
||||
MagicMock(
|
||||
message=MagicMock(
|
||||
content='{"final_summary": "Test final summary"}',
|
||||
content='<json_output id="test123456">{"final_summary": "Test final summary"}</json_output>',
|
||||
tool_calls=None,
|
||||
)
|
||||
)
|
||||
@@ -298,11 +304,13 @@ class TestLLMStatsTracking:
|
||||
max_tokens=1000, # Large enough to avoid chunking
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
# Mock secrets.token_hex to return consistent ID
|
||||
with patch("secrets.token_hex", return_value="test123456"):
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
print(f"Actual calls made: {call_count}")
|
||||
print(f"Block stats: {block.execution_stats}")
|
||||
@@ -457,7 +465,7 @@ class TestLLMStatsTracking:
|
||||
return llm.LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response='{"result": "test"}',
|
||||
response='<json_output id="test123456">{"result": "test"}</json_output>',
|
||||
tool_calls=None,
|
||||
prompt_tokens=10,
|
||||
completion_tokens=20,
|
||||
@@ -476,10 +484,12 @@ class TestLLMStatsTracking:
|
||||
|
||||
# Run the block
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
# Mock secrets.token_hex to return consistent ID
|
||||
with patch("secrets.token_hex", return_value="test123456"):
|
||||
async for output_name, output_data in block.run(
|
||||
input_data, credentials=llm.TEST_CREDENTIALS
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Block finished - now grab and assert stats
|
||||
assert block.execution_stats is not None
|
||||
|
||||
@@ -249,3 +249,232 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
# Verify outputs
|
||||
assert "finished" in outputs # Should have finished since no tool calls
|
||||
assert outputs["finished"] == "I need to think about this."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_parameter_validation():
|
||||
"""Test that SmartDecisionMakerBlock correctly validates tool call parameters."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Mock tool functions with specific parameter schema
|
||||
mock_tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_keywords",
|
||||
"description": "Search for keywords with difficulty filtering",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"},
|
||||
"max_keyword_difficulty": {
|
||||
"type": "integer",
|
||||
"description": "Maximum keyword difficulty (required)",
|
||||
},
|
||||
"optional_param": {
|
||||
"type": "string",
|
||||
"description": "Optional parameter with default",
|
||||
"default": "default_value",
|
||||
},
|
||||
},
|
||||
"required": ["query", "max_keyword_difficulty"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Test case 1: Tool call with TYPO in parameter name (should retry and eventually fail)
|
||||
mock_tool_call_with_typo = MagicMock()
|
||||
mock_tool_call_with_typo.function.name = "search_keywords"
|
||||
mock_tool_call_with_typo.function.arguments = '{"query": "test", "maximum_keyword_difficulty": 50}' # TYPO: maximum instead of max
|
||||
|
||||
mock_response_with_typo = MagicMock()
|
||||
mock_response_with_typo.response = None
|
||||
mock_response_with_typo.tool_calls = [mock_tool_call_with_typo]
|
||||
mock_response_with_typo.prompt_tokens = 50
|
||||
mock_response_with_typo.completion_tokens = 25
|
||||
mock_response_with_typo.reasoning = None
|
||||
mock_response_with_typo.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call", return_value=mock_response_with_typo
|
||||
) as mock_llm_call, patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2, # Set retry to 2 for testing
|
||||
)
|
||||
|
||||
# Should raise ValueError after retries due to typo'd parameter name
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Verify error message contains details about the typo
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Tool call validation failed" in error_msg
|
||||
assert "Unknown parameters: ['maximum_keyword_difficulty']" in error_msg
|
||||
|
||||
# Verify that LLM was called the expected number of times (retries)
|
||||
assert mock_llm_call.call_count == 2 # Should retry based on input_data.retry
|
||||
|
||||
# Test case 2: Tool call missing REQUIRED parameter (should raise ValueError)
|
||||
mock_tool_call_missing_required = MagicMock()
|
||||
mock_tool_call_missing_required.function.name = "search_keywords"
|
||||
mock_tool_call_missing_required.function.arguments = (
|
||||
'{"query": "test"}' # Missing required max_keyword_difficulty
|
||||
)
|
||||
|
||||
mock_response_missing_required = MagicMock()
|
||||
mock_response_missing_required.response = None
|
||||
mock_response_missing_required.tool_calls = [mock_tool_call_missing_required]
|
||||
mock_response_missing_required.prompt_tokens = 50
|
||||
mock_response_missing_required.completion_tokens = 25
|
||||
mock_response_missing_required.reasoning = None
|
||||
mock_response_missing_required.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call", return_value=mock_response_missing_required
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
# Should raise ValueError due to missing required parameter
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Tool call 'search_keywords' has parameter errors" in error_msg
|
||||
assert "Missing required parameters: ['max_keyword_difficulty']" in error_msg
|
||||
|
||||
# Test case 3: Valid tool call with OPTIONAL parameter missing (should succeed)
|
||||
mock_tool_call_valid = MagicMock()
|
||||
mock_tool_call_valid.function.name = "search_keywords"
|
||||
mock_tool_call_valid.function.arguments = '{"query": "test", "max_keyword_difficulty": 50}' # optional_param missing, but that's OK
|
||||
|
||||
mock_response_valid = MagicMock()
|
||||
mock_response_valid.response = None
|
||||
mock_response_valid.tool_calls = [mock_tool_call_valid]
|
||||
mock_response_valid.prompt_tokens = 50
|
||||
mock_response_valid.completion_tokens = 25
|
||||
mock_response_valid.reasoning = None
|
||||
mock_response_valid.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call", return_value=mock_response_valid
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
# Should succeed - optional parameter missing is OK
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Verify tool outputs were generated correctly
|
||||
assert "tools_^_search_keywords_~_query" in outputs
|
||||
assert outputs["tools_^_search_keywords_~_query"] == "test"
|
||||
assert "tools_^_search_keywords_~_max_keyword_difficulty" in outputs
|
||||
assert outputs["tools_^_search_keywords_~_max_keyword_difficulty"] == 50
|
||||
# Optional parameter should be None when not provided
|
||||
assert "tools_^_search_keywords_~_optional_param" in outputs
|
||||
assert outputs["tools_^_search_keywords_~_optional_param"] is None
|
||||
|
||||
# Test case 4: Valid tool call with ALL parameters (should succeed)
|
||||
mock_tool_call_all_params = MagicMock()
|
||||
mock_tool_call_all_params.function.name = "search_keywords"
|
||||
mock_tool_call_all_params.function.arguments = '{"query": "test", "max_keyword_difficulty": 50, "optional_param": "custom_value"}'
|
||||
|
||||
mock_response_all_params = MagicMock()
|
||||
mock_response_all_params.response = None
|
||||
mock_response_all_params.tool_calls = [mock_tool_call_all_params]
|
||||
mock_response_all_params.prompt_tokens = 50
|
||||
mock_response_all_params.completion_tokens = 25
|
||||
mock_response_all_params.reasoning = None
|
||||
mock_response_all_params.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call", return_value=mock_response_all_params
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
# Should succeed with all parameters
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Verify all tool outputs were generated correctly
|
||||
assert outputs["tools_^_search_keywords_~_query"] == "test"
|
||||
assert outputs["tools_^_search_keywords_~_max_keyword_difficulty"] == 50
|
||||
assert outputs["tools_^_search_keywords_~_optional_param"] == "custom_value"
|
||||
|
||||
131
autogpt_platform/backend/backend/blocks/test/test_table_input.py
Normal file
131
autogpt_platform/backend/backend/blocks/test/test_table_input.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import pytest
|
||||
|
||||
from backend.blocks.io import AgentTableInputBlock
|
||||
from backend.util.test import execute_block_test
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_input_block():
|
||||
"""Test the AgentTableInputBlock with basic input/output."""
|
||||
block = AgentTableInputBlock()
|
||||
await execute_block_test(block)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_input_with_data():
|
||||
"""Test AgentTableInputBlock with actual table data."""
|
||||
block = AgentTableInputBlock()
|
||||
|
||||
input_data = block.Input(
|
||||
name="test_table",
|
||||
column_headers=["Name", "Age", "City"],
|
||||
value=[
|
||||
{"Name": "John", "Age": "30", "City": "New York"},
|
||||
{"Name": "Jane", "Age": "25", "City": "London"},
|
||||
{"Name": "Bob", "Age": "35", "City": "Paris"},
|
||||
],
|
||||
)
|
||||
|
||||
output_data = []
|
||||
async for output_name, output_value in block.run(input_data):
|
||||
output_data.append((output_name, output_value))
|
||||
|
||||
assert len(output_data) == 1
|
||||
assert output_data[0][0] == "result"
|
||||
|
||||
result = output_data[0][1]
|
||||
assert len(result) == 3
|
||||
assert result[0]["Name"] == "John"
|
||||
assert result[1]["Age"] == "25"
|
||||
assert result[2]["City"] == "Paris"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_input_empty_data():
|
||||
"""Test AgentTableInputBlock with empty data."""
|
||||
block = AgentTableInputBlock()
|
||||
|
||||
input_data = block.Input(
|
||||
name="empty_table", column_headers=["Col1", "Col2"], value=[]
|
||||
)
|
||||
|
||||
output_data = []
|
||||
async for output_name, output_value in block.run(input_data):
|
||||
output_data.append((output_name, output_value))
|
||||
|
||||
assert len(output_data) == 1
|
||||
assert output_data[0][0] == "result"
|
||||
assert output_data[0][1] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_input_with_missing_columns():
|
||||
"""Test AgentTableInputBlock passes through data with missing columns as-is."""
|
||||
block = AgentTableInputBlock()
|
||||
|
||||
input_data = block.Input(
|
||||
name="partial_table",
|
||||
column_headers=["Name", "Age", "City"],
|
||||
value=[
|
||||
{"Name": "John", "Age": "30"}, # Missing City
|
||||
{"Name": "Jane", "City": "London"}, # Missing Age
|
||||
{"Age": "35", "City": "Paris"}, # Missing Name
|
||||
],
|
||||
)
|
||||
|
||||
output_data = []
|
||||
async for output_name, output_value in block.run(input_data):
|
||||
output_data.append((output_name, output_value))
|
||||
|
||||
result = output_data[0][1]
|
||||
assert len(result) == 3
|
||||
|
||||
# Check data is passed through as-is
|
||||
assert result[0] == {"Name": "John", "Age": "30"}
|
||||
assert result[1] == {"Name": "Jane", "City": "London"}
|
||||
assert result[2] == {"Age": "35", "City": "Paris"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_input_none_value():
|
||||
"""Test AgentTableInputBlock with None value returns empty list."""
|
||||
block = AgentTableInputBlock()
|
||||
|
||||
input_data = block.Input(
|
||||
name="none_table", column_headers=["Name", "Age"], value=None
|
||||
)
|
||||
|
||||
output_data = []
|
||||
async for output_name, output_value in block.run(input_data):
|
||||
output_data.append((output_name, output_value))
|
||||
|
||||
assert len(output_data) == 1
|
||||
assert output_data[0][0] == "result"
|
||||
assert output_data[0][1] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_input_with_default_headers():
|
||||
"""Test AgentTableInputBlock with default column headers."""
|
||||
block = AgentTableInputBlock()
|
||||
|
||||
# Don't specify column_headers, should use defaults
|
||||
input_data = block.Input(
|
||||
name="default_headers_table",
|
||||
value=[
|
||||
{"Column 1": "A", "Column 2": "B", "Column 3": "C"},
|
||||
{"Column 1": "D", "Column 2": "E", "Column 3": "F"},
|
||||
],
|
||||
)
|
||||
|
||||
output_data = []
|
||||
async for output_name, output_value in block.run(input_data):
|
||||
output_data.append((output_name, output_value))
|
||||
|
||||
assert len(output_data) == 1
|
||||
assert output_data[0][0] == "result"
|
||||
|
||||
result = output_data[0][1]
|
||||
assert len(result) == 2
|
||||
assert result[0]["Column 1"] == "A"
|
||||
assert result[1]["Column 3"] == "F"
|
||||
@@ -172,6 +172,11 @@ class FillTextTemplateBlock(Block):
|
||||
format: str = SchemaField(
|
||||
description="Template to format the text using `values`. Use Jinja2 syntax."
|
||||
)
|
||||
escape_html: bool = SchemaField(
|
||||
default=False,
|
||||
advanced=True,
|
||||
description="Whether to escape special characters in the inserted values to be HTML-safe. Enable for HTML output, disable for plain text.",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: str = SchemaField(description="Formatted text")
|
||||
@@ -205,6 +210,7 @@ class FillTextTemplateBlock(Block):
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
formatter = text.TextFormatter(autoescape=input_data.escape_html)
|
||||
yield "output", formatter.format_string(input_data.format, input_data.values)
|
||||
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from prisma.models import APIKey as PrismaAPIKey
|
||||
from prisma.types import APIKeyWhereUniqueInput
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.includes import MAX_USER_API_KEYS_FETCH
|
||||
from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -178,9 +179,13 @@ async def revoke_api_key(key_id: str, user_id: str) -> APIKeyInfo:
|
||||
return APIKeyInfo.from_db(updated_api_key)
|
||||
|
||||
|
||||
async def list_user_api_keys(user_id: str) -> list[APIKeyInfo]:
|
||||
async def list_user_api_keys(
|
||||
user_id: str, limit: int = MAX_USER_API_KEYS_FETCH
|
||||
) -> list[APIKeyInfo]:
|
||||
api_keys = await PrismaAPIKey.prisma().find_many(
|
||||
where={"userId": user_id}, order={"createdAt": "desc"}
|
||||
where={"userId": user_id},
|
||||
order={"createdAt": "desc"},
|
||||
take=limit,
|
||||
)
|
||||
|
||||
return [APIKeyInfo.from_db(key) for key in api_keys]
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
@@ -21,6 +20,7 @@ from typing import (
|
||||
|
||||
import jsonref
|
||||
import jsonschema
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from prisma.models import AgentBlock
|
||||
from prisma.types import AgentBlockCreateInput
|
||||
from pydantic import BaseModel
|
||||
@@ -722,7 +722,7 @@ def get_block(block_id: str) -> Block[BlockSchema, BlockSchema] | None:
|
||||
return cls() if cls else None
|
||||
|
||||
|
||||
@functools.cache
|
||||
@cached()
|
||||
def get_webhook_block_ids() -> Sequence[str]:
|
||||
return [
|
||||
id
|
||||
@@ -731,7 +731,7 @@ def get_webhook_block_ids() -> Sequence[str]:
|
||||
]
|
||||
|
||||
|
||||
@functools.cache
|
||||
@cached()
|
||||
def get_io_block_ids() -> Sequence[str]:
|
||||
return [
|
||||
id
|
||||
|
||||
@@ -69,6 +69,7 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.CLAUDE_4_1_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_OPUS: 21,
|
||||
LlmModel.CLAUDE_4_SONNET: 5,
|
||||
LlmModel.CLAUDE_4_5_SONNET: 9,
|
||||
LlmModel.CLAUDE_3_7_SONNET: 5,
|
||||
LlmModel.CLAUDE_3_5_SONNET: 4,
|
||||
LlmModel.CLAUDE_3_5_HAIKU: 1, # $0.80 / $4.00
|
||||
|
||||
@@ -23,6 +23,7 @@ from pydantic import BaseModel
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.includes import MAX_CREDIT_REFUND_REQUESTS_FETCH
|
||||
from backend.data.model import (
|
||||
AutoTopUpConfig,
|
||||
RefundRequest,
|
||||
@@ -905,7 +906,9 @@ class UserCredit(UserCreditBase):
|
||||
),
|
||||
)
|
||||
|
||||
async def get_refund_requests(self, user_id: str) -> list[RefundRequest]:
|
||||
async def get_refund_requests(
|
||||
self, user_id: str, limit: int = MAX_CREDIT_REFUND_REQUESTS_FETCH
|
||||
) -> list[RefundRequest]:
|
||||
return [
|
||||
RefundRequest(
|
||||
id=r.id,
|
||||
@@ -921,6 +924,7 @@ class UserCredit(UserCreditBase):
|
||||
for r in await CreditRefundRequest.prisma().find_many(
|
||||
where={"userId": user_id},
|
||||
order={"createdAt": "desc"},
|
||||
take=limit,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ async def disconnect():
|
||||
|
||||
|
||||
# Transaction timeout constant (in milliseconds)
|
||||
TRANSACTION_TIMEOUT = 15000 # 15 seconds - Increased from 5s to prevent timeout errors
|
||||
TRANSACTION_TIMEOUT = 30000 # 30 seconds - Increased from 15s to prevent timeout errors during graph creation under load
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
||||
@@ -20,6 +20,7 @@ from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.db import prisma as db
|
||||
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsFieldInfo,
|
||||
@@ -29,6 +30,7 @@ from backend.data.model import (
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import type as type_utils
|
||||
from backend.util.json import SafeJson
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
|
||||
from .db import BaseDbModel, query_raw_with_schema, transaction
|
||||
@@ -746,6 +748,13 @@ class GraphMeta(Graph):
|
||||
return GraphMeta(**graph.model_dump())
|
||||
|
||||
|
||||
class GraphsPaginated(BaseModel):
|
||||
"""Response schema for paginated graphs."""
|
||||
|
||||
graphs: list[GraphMeta]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
# --------------------- CRUD functions --------------------- #
|
||||
|
||||
|
||||
@@ -774,31 +783,42 @@ async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
|
||||
return NodeModel.from_db(node)
|
||||
|
||||
|
||||
async def list_graphs(
|
||||
async def list_graphs_paginated(
|
||||
user_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 25,
|
||||
filter_by: Literal["active"] | None = "active",
|
||||
) -> list[GraphMeta]:
|
||||
) -> GraphsPaginated:
|
||||
"""
|
||||
Retrieves graph metadata objects.
|
||||
Default behaviour is to get all currently active graphs.
|
||||
Retrieves paginated graph metadata objects.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user that owns the graphs.
|
||||
page: Page number (1-based).
|
||||
page_size: Number of graphs per page.
|
||||
filter_by: An optional filter to either select graphs.
|
||||
user_id: The ID of the user that owns the graph.
|
||||
|
||||
Returns:
|
||||
list[GraphMeta]: A list of objects representing the retrieved graphs.
|
||||
GraphsPaginated: Paginated list of graph metadata.
|
||||
"""
|
||||
where_clause: AgentGraphWhereInput = {"userId": user_id}
|
||||
|
||||
if filter_by == "active":
|
||||
where_clause["isActive"] = True
|
||||
|
||||
# Get total count
|
||||
total_count = await AgentGraph.prisma().count(where=where_clause)
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
# Get paginated results
|
||||
offset = (page - 1) * page_size
|
||||
graphs = await AgentGraph.prisma().find_many(
|
||||
where=where_clause,
|
||||
distinct=["id"],
|
||||
order={"version": "desc"},
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
skip=offset,
|
||||
take=page_size,
|
||||
)
|
||||
|
||||
graph_models: list[GraphMeta] = []
|
||||
@@ -812,7 +832,15 @@ async def list_graphs(
|
||||
logger.error(f"Error processing graph {graph.id}: {e}")
|
||||
continue
|
||||
|
||||
return graph_models
|
||||
return GraphsPaginated(
|
||||
graphs=graph_models,
|
||||
pagination=Pagination(
|
||||
total_items=total_count,
|
||||
total_pages=total_pages,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def get_graph_metadata(graph_id: str, version: int | None = None) -> Graph | None:
|
||||
@@ -1032,11 +1060,14 @@ async def set_graph_active_version(graph_id: str, version: int, user_id: str) ->
|
||||
)
|
||||
|
||||
|
||||
async def get_graph_all_versions(graph_id: str, user_id: str) -> list[GraphModel]:
|
||||
async def get_graph_all_versions(
|
||||
graph_id: str, user_id: str, limit: int = MAX_GRAPH_VERSIONS_FETCH
|
||||
) -> list[GraphModel]:
|
||||
graph_versions = await AgentGraph.prisma().find_many(
|
||||
where={"id": graph_id, "userId": user_id},
|
||||
order={"version": "desc"},
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
take=limit,
|
||||
)
|
||||
|
||||
if not graph_versions:
|
||||
|
||||
@@ -14,6 +14,7 @@ AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
|
||||
"Nodes": {"include": AGENT_NODE_INCLUDE}
|
||||
}
|
||||
|
||||
|
||||
EXECUTION_RESULT_ORDER: list[prisma.types.AgentNodeExecutionOrderByInput] = [
|
||||
{"queuedTime": "desc"},
|
||||
# Fallback: Incomplete execs has no queuedTime.
|
||||
@@ -28,6 +29,13 @@ EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
|
||||
}
|
||||
|
||||
MAX_NODE_EXECUTIONS_FETCH = 1000
|
||||
MAX_LIBRARY_AGENT_EXECUTIONS_FETCH = 10
|
||||
|
||||
# Default limits for potentially large result sets
|
||||
MAX_CREDIT_REFUND_REQUESTS_FETCH = 100
|
||||
MAX_INTEGRATION_WEBHOOKS_FETCH = 100
|
||||
MAX_USER_API_KEYS_FETCH = 500
|
||||
MAX_GRAPH_VERSIONS_FETCH = 50
|
||||
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
|
||||
"NodeExecutions": {
|
||||
@@ -71,13 +79,56 @@ INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
|
||||
}
|
||||
|
||||
|
||||
def library_agent_include(user_id: str) -> prisma.types.LibraryAgentInclude:
|
||||
return {
|
||||
"AgentGraph": {
|
||||
"include": {
|
||||
**AGENT_GRAPH_INCLUDE,
|
||||
"Executions": {"where": {"userId": user_id}},
|
||||
}
|
||||
},
|
||||
"Creator": True,
|
||||
def library_agent_include(
|
||||
user_id: str,
|
||||
include_nodes: bool = True,
|
||||
include_executions: bool = True,
|
||||
execution_limit: int = MAX_LIBRARY_AGENT_EXECUTIONS_FETCH,
|
||||
) -> prisma.types.LibraryAgentInclude:
|
||||
"""
|
||||
Fully configurable includes for library agent queries with performance optimization.
|
||||
|
||||
Args:
|
||||
user_id: User ID for filtering user-specific data
|
||||
include_nodes: Whether to include graph nodes (default: True, needed for get_sub_graphs)
|
||||
include_executions: Whether to include executions (default: True, safe with execution_limit)
|
||||
execution_limit: Limit on executions to fetch (default: MAX_LIBRARY_AGENT_EXECUTIONS_FETCH)
|
||||
|
||||
Defaults maintain backward compatibility and safety - includes everything needed for all functionality.
|
||||
For performance optimization, explicitly set include_nodes=False and include_executions=False
|
||||
for listing views where frontend fetches data separately.
|
||||
|
||||
Performance impact:
|
||||
- Default (full nodes + limited executions): Original performance, works everywhere
|
||||
- Listing optimization (no nodes/executions): ~2s for 15 agents vs potential timeouts
|
||||
- Unlimited executions: varies by user (thousands of executions = timeouts)
|
||||
"""
|
||||
result: prisma.types.LibraryAgentInclude = {
|
||||
"Creator": True, # Always needed for creator info
|
||||
}
|
||||
|
||||
# Build AgentGraph include based on requested options
|
||||
if include_nodes or include_executions:
|
||||
agent_graph_include = {}
|
||||
|
||||
# Add nodes if requested (always full nodes)
|
||||
if include_nodes:
|
||||
agent_graph_include.update(AGENT_GRAPH_INCLUDE) # Full nodes
|
||||
|
||||
# Add executions if requested
|
||||
if include_executions:
|
||||
agent_graph_include["Executions"] = {
|
||||
"where": {"userId": user_id},
|
||||
"order_by": {"createdAt": "desc"},
|
||||
"take": execution_limit,
|
||||
}
|
||||
|
||||
result["AgentGraph"] = cast(
|
||||
prisma.types.AgentGraphArgsFromLibraryAgent,
|
||||
{"include": agent_graph_include},
|
||||
)
|
||||
else:
|
||||
# Default: Basic metadata only (fast - recommended for most use cases)
|
||||
result["AgentGraph"] = True # Basic graph metadata (name, description, id)
|
||||
|
||||
return result
|
||||
|
||||
@@ -11,7 +11,10 @@ from prisma.types import (
|
||||
from pydantic import Field, computed_field
|
||||
|
||||
from backend.data.event_bus import AsyncRedisEventBus
|
||||
from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
|
||||
from backend.data.includes import (
|
||||
INTEGRATION_WEBHOOK_INCLUDE,
|
||||
MAX_INTEGRATION_WEBHOOKS_FETCH,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks.utils import webhook_ingress_url
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
@@ -128,22 +131,36 @@ async def get_webhook(
|
||||
|
||||
@overload
|
||||
async def get_all_webhooks_by_creds(
|
||||
user_id: str, credentials_id: str, *, include_relations: Literal[True]
|
||||
user_id: str,
|
||||
credentials_id: str,
|
||||
*,
|
||||
include_relations: Literal[True],
|
||||
limit: int = MAX_INTEGRATION_WEBHOOKS_FETCH,
|
||||
) -> list[WebhookWithRelations]: ...
|
||||
@overload
|
||||
async def get_all_webhooks_by_creds(
|
||||
user_id: str, credentials_id: str, *, include_relations: Literal[False] = False
|
||||
user_id: str,
|
||||
credentials_id: str,
|
||||
*,
|
||||
include_relations: Literal[False] = False,
|
||||
limit: int = MAX_INTEGRATION_WEBHOOKS_FETCH,
|
||||
) -> list[Webhook]: ...
|
||||
|
||||
|
||||
async def get_all_webhooks_by_creds(
|
||||
user_id: str, credentials_id: str, *, include_relations: bool = False
|
||||
user_id: str,
|
||||
credentials_id: str,
|
||||
*,
|
||||
include_relations: bool = False,
|
||||
limit: int = MAX_INTEGRATION_WEBHOOKS_FETCH,
|
||||
) -> list[Webhook] | list[WebhookWithRelations]:
|
||||
if not credentials_id:
|
||||
raise ValueError("credentials_id must not be empty")
|
||||
webhooks = await IntegrationWebhook.prisma().find_many(
|
||||
where={"userId": user_id, "credentialsId": credentials_id},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE if include_relations else None,
|
||||
order={"createdAt": "desc"},
|
||||
take=limit,
|
||||
)
|
||||
return [
|
||||
(WebhookWithRelations if include_relations else Webhook).from_db(webhook)
|
||||
|
||||
@@ -270,6 +270,7 @@ def SchemaField(
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
discriminator: Optional[str] = None,
|
||||
format: Optional[str] = None,
|
||||
json_schema_extra: Optional[dict[str, Any]] = None,
|
||||
) -> T:
|
||||
if default is PydanticUndefined and default_factory is None:
|
||||
@@ -285,6 +286,7 @@ def SchemaField(
|
||||
"advanced": advanced,
|
||||
"hidden": hidden,
|
||||
"depends_on": depends_on,
|
||||
"format": format,
|
||||
**(json_schema_extra or {}),
|
||||
}.items()
|
||||
if v is not None
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
import prisma
|
||||
import pydantic
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from prisma.enums import OnboardingStep
|
||||
from prisma.models import UserOnboarding
|
||||
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
|
||||
@@ -30,7 +32,7 @@ user_credit = get_user_credit_model()
|
||||
|
||||
class UserOnboardingUpdate(pydantic.BaseModel):
|
||||
completedSteps: Optional[list[OnboardingStep]] = None
|
||||
notificationDot: Optional[bool] = None
|
||||
walletShown: Optional[bool] = None
|
||||
notified: Optional[list[OnboardingStep]] = None
|
||||
usageReason: Optional[str] = None
|
||||
integrations: Optional[list[str]] = None
|
||||
@@ -39,6 +41,8 @@ class UserOnboardingUpdate(pydantic.BaseModel):
|
||||
agentInput: Optional[dict[str, Any]] = None
|
||||
onboardingAgentExecutionId: Optional[str] = None
|
||||
agentRuns: Optional[int] = None
|
||||
lastRunAt: Optional[datetime] = None
|
||||
consecutiveRunDays: Optional[int] = None
|
||||
|
||||
|
||||
async def get_user_onboarding(user_id: str):
|
||||
@@ -57,16 +61,22 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
|
||||
update["completedSteps"] = list(set(data.completedSteps))
|
||||
for step in (
|
||||
OnboardingStep.AGENT_NEW_RUN,
|
||||
OnboardingStep.RUN_AGENTS,
|
||||
OnboardingStep.MARKETPLACE_VISIT,
|
||||
OnboardingStep.MARKETPLACE_ADD_AGENT,
|
||||
OnboardingStep.MARKETPLACE_RUN_AGENT,
|
||||
OnboardingStep.BUILDER_SAVE_AGENT,
|
||||
OnboardingStep.BUILDER_RUN_AGENT,
|
||||
OnboardingStep.RE_RUN_AGENT,
|
||||
OnboardingStep.SCHEDULE_AGENT,
|
||||
OnboardingStep.RUN_AGENTS,
|
||||
OnboardingStep.RUN_3_DAYS,
|
||||
OnboardingStep.TRIGGER_WEBHOOK,
|
||||
OnboardingStep.RUN_14_DAYS,
|
||||
OnboardingStep.RUN_AGENTS_100,
|
||||
):
|
||||
if step in data.completedSteps:
|
||||
await reward_user(user_id, step)
|
||||
if data.notificationDot is not None:
|
||||
update["notificationDot"] = data.notificationDot
|
||||
if data.walletShown is not None:
|
||||
update["walletShown"] = data.walletShown
|
||||
if data.notified is not None:
|
||||
update["notified"] = list(set(data.notified))
|
||||
if data.usageReason is not None:
|
||||
@@ -83,6 +93,10 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
|
||||
update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId
|
||||
if data.agentRuns is not None:
|
||||
update["agentRuns"] = data.agentRuns
|
||||
if data.lastRunAt is not None:
|
||||
update["lastRunAt"] = data.lastRunAt
|
||||
if data.consecutiveRunDays is not None:
|
||||
update["consecutiveRunDays"] = data.consecutiveRunDays
|
||||
|
||||
return await UserOnboarding.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
@@ -101,16 +115,28 @@ async def reward_user(user_id: str, step: OnboardingStep):
|
||||
# This is seen as a reward for the GET_RESULTS step in the wallet
|
||||
case OnboardingStep.AGENT_NEW_RUN:
|
||||
reward = 300
|
||||
case OnboardingStep.RUN_AGENTS:
|
||||
reward = 300
|
||||
case OnboardingStep.MARKETPLACE_VISIT:
|
||||
reward = 100
|
||||
case OnboardingStep.MARKETPLACE_ADD_AGENT:
|
||||
reward = 100
|
||||
case OnboardingStep.MARKETPLACE_RUN_AGENT:
|
||||
reward = 100
|
||||
case OnboardingStep.BUILDER_SAVE_AGENT:
|
||||
reward = 100
|
||||
case OnboardingStep.BUILDER_RUN_AGENT:
|
||||
case OnboardingStep.RE_RUN_AGENT:
|
||||
reward = 100
|
||||
case OnboardingStep.SCHEDULE_AGENT:
|
||||
reward = 100
|
||||
case OnboardingStep.RUN_AGENTS:
|
||||
reward = 300
|
||||
case OnboardingStep.RUN_3_DAYS:
|
||||
reward = 100
|
||||
case OnboardingStep.TRIGGER_WEBHOOK:
|
||||
reward = 100
|
||||
case OnboardingStep.RUN_14_DAYS:
|
||||
reward = 300
|
||||
case OnboardingStep.RUN_AGENTS_100:
|
||||
reward = 300
|
||||
|
||||
if reward == 0:
|
||||
return
|
||||
@@ -132,6 +158,22 @@ async def reward_user(user_id: str, step: OnboardingStep):
|
||||
)
|
||||
|
||||
|
||||
async def complete_webhook_trigger_step(user_id: str):
|
||||
"""
|
||||
Completes the TRIGGER_WEBHOOK onboarding step for the user if not already completed.
|
||||
"""
|
||||
|
||||
onboarding = await get_user_onboarding(user_id)
|
||||
if OnboardingStep.TRIGGER_WEBHOOK not in onboarding.completedSteps:
|
||||
await update_user_onboarding(
|
||||
user_id,
|
||||
UserOnboardingUpdate(
|
||||
completedSteps=onboarding.completedSteps
|
||||
+ [OnboardingStep.TRIGGER_WEBHOOK]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def clean_and_split(text: str) -> list[str]:
|
||||
"""
|
||||
Removes all special characters from a string, truncates it to 100 characters,
|
||||
@@ -333,8 +375,13 @@ async def get_recommended_agents(user_id: str) -> list[StoreAgentDetails]:
|
||||
]
|
||||
|
||||
|
||||
@cached(maxsize=1, ttl_seconds=300) # Cache for 5 minutes since this rarely changes
|
||||
async def onboarding_enabled() -> bool:
|
||||
"""
|
||||
Check if onboarding should be enabled based on store agent count.
|
||||
Cached to prevent repeated slow database queries.
|
||||
"""
|
||||
# Use a more efficient query that stops counting after finding enough agents
|
||||
count = await prisma.models.StoreAgent.prisma().count(take=MIN_AGENT_COUNT + 1)
|
||||
|
||||
# Onboading is enabled if there are at least 2 agents in the store
|
||||
# Onboarding is enabled if there are at least 2 agents in the store
|
||||
return count >= MIN_AGENT_COUNT
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
from functools import cache
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from autogpt_libs.utils.cache import cached, thread_cached
|
||||
from dotenv import load_dotenv
|
||||
from redis import Redis
|
||||
from redis.asyncio import Redis as AsyncRedis
|
||||
@@ -13,7 +12,7 @@ load_dotenv()
|
||||
|
||||
HOST = os.getenv("REDIS_HOST", "localhost")
|
||||
PORT = int(os.getenv("REDIS_PORT", "6379"))
|
||||
PASSWORD = os.getenv("REDIS_PASSWORD", "password")
|
||||
PASSWORD = os.getenv("REDIS_PASSWORD", None)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,7 +34,7 @@ def disconnect():
|
||||
get_redis().close()
|
||||
|
||||
|
||||
@cache
|
||||
@cached()
|
||||
def get_redis() -> Redis:
|
||||
return connect()
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Optional, cast
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from autogpt_libs.auth.models import DEFAULT_USER_ID
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from fastapi import HTTPException
|
||||
from prisma.enums import NotificationType
|
||||
from prisma.models import User as PrismaUser
|
||||
@@ -23,7 +24,11 @@ from backend.util.settings import Settings
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
# Cache decorator alias for consistent user lookup caching
|
||||
cache_user_lookup = cached(maxsize=1000, ttl_seconds=300)
|
||||
|
||||
|
||||
@cache_user_lookup
|
||||
async def get_or_create_user(user_data: dict) -> User:
|
||||
try:
|
||||
user_id = user_data.get("sub")
|
||||
@@ -49,6 +54,7 @@ async def get_or_create_user(user_data: dict) -> User:
|
||||
raise DatabaseError(f"Failed to get or create user {user_data}: {e}") from e
|
||||
|
||||
|
||||
@cache_user_lookup
|
||||
async def get_user_by_id(user_id: str) -> User:
|
||||
user = await prisma.user.find_unique(where={"id": user_id})
|
||||
if not user:
|
||||
@@ -64,6 +70,7 @@ async def get_user_email_by_id(user_id: str) -> Optional[str]:
|
||||
raise DatabaseError(f"Failed to get user email for user {user_id}: {e}") from e
|
||||
|
||||
|
||||
@cache_user_lookup
|
||||
async def get_user_by_email(email: str) -> Optional[User]:
|
||||
try:
|
||||
user = await prisma.user.find_unique(where={"email": email})
|
||||
@@ -74,7 +81,17 @@ async def get_user_by_email(email: str) -> Optional[User]:
|
||||
|
||||
async def update_user_email(user_id: str, email: str):
|
||||
try:
|
||||
# Get old email first for cache invalidation
|
||||
old_user = await prisma.user.find_unique(where={"id": user_id})
|
||||
old_email = old_user.email if old_user else None
|
||||
|
||||
await prisma.user.update(where={"id": user_id}, data={"email": email})
|
||||
|
||||
# Selectively invalidate only the specific user entries
|
||||
get_user_by_id.cache_delete(user_id)
|
||||
if old_email:
|
||||
get_user_by_email.cache_delete(old_email)
|
||||
get_user_by_email.cache_delete(email)
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to update user email for user {user_id}: {e}"
|
||||
@@ -114,6 +131,8 @@ async def update_user_integrations(user_id: str, data: UserIntegrations):
|
||||
where={"id": user_id},
|
||||
data={"integrations": encrypted_data},
|
||||
)
|
||||
# Invalidate cache for this user
|
||||
get_user_by_id.cache_delete(user_id)
|
||||
|
||||
|
||||
async def migrate_and_encrypt_user_integrations():
|
||||
@@ -285,6 +304,10 @@ async def update_user_notification_preference(
|
||||
)
|
||||
if not user:
|
||||
raise ValueError(f"User not found with ID: {user_id}")
|
||||
|
||||
# Invalidate cache for this user since notification preferences are part of user data
|
||||
get_user_by_id.cache_delete(user_id)
|
||||
|
||||
preferences: dict[NotificationType, bool] = {
|
||||
NotificationType.AGENT_RUN: user.notifyOnAgentRun or True,
|
||||
NotificationType.ZERO_BALANCE: user.notifyOnZeroBalance or True,
|
||||
@@ -323,6 +346,8 @@ async def set_user_email_verification(user_id: str, verified: bool) -> None:
|
||||
where={"id": user_id},
|
||||
data={"emailVerified": verified},
|
||||
)
|
||||
# Invalidate cache for this user
|
||||
get_user_by_id.cache_delete(user_id)
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to set email verification status for user {user_id}: {e}"
|
||||
@@ -407,6 +432,10 @@ async def update_user_timezone(user_id: str, timezone: str) -> User:
|
||||
)
|
||||
if not user:
|
||||
raise ValueError(f"User not found with ID: {user_id}")
|
||||
|
||||
# Invalidate cache for this user
|
||||
get_user_by_id.cache_delete(user_id)
|
||||
|
||||
return User.from_db(user)
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"Failed to update timezone for user {user_id}: {e}") from e
|
||||
|
||||
@@ -423,7 +423,6 @@ async def _call_llm_direct(
|
||||
credentials=credentials,
|
||||
llm_model=LlmModel.GPT4O_MINI,
|
||||
prompt=prompt,
|
||||
json_format=False,
|
||||
max_tokens=150,
|
||||
compress_prompt_to_fit=True,
|
||||
)
|
||||
|
||||
115
autogpt_platform/backend/backend/executor/cluster_lock.py
Normal file
115
autogpt_platform/backend/backend/executor/cluster_lock.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Redis-based distributed locking for cluster coordination."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis import Redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClusterLock:
|
||||
"""Simple Redis-based distributed lock for preventing duplicate execution."""
|
||||
|
||||
def __init__(self, redis: "Redis", key: str, owner_id: str, timeout: int = 300):
|
||||
self.redis = redis
|
||||
self.key = key
|
||||
self.owner_id = owner_id
|
||||
self.timeout = timeout
|
||||
self._last_refresh = 0.0
|
||||
|
||||
def try_acquire(self) -> str | None:
|
||||
"""Try to acquire the lock.
|
||||
|
||||
Returns:
|
||||
- owner_id (self.owner_id) if successfully acquired
|
||||
- different owner_id if someone else holds the lock
|
||||
- None if Redis is unavailable or other error
|
||||
"""
|
||||
try:
|
||||
success = self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout)
|
||||
if success:
|
||||
self._last_refresh = time.time()
|
||||
return self.owner_id # Successfully acquired
|
||||
|
||||
# Failed to acquire, get current owner
|
||||
current_value = self.redis.get(self.key)
|
||||
if current_value:
|
||||
current_owner = (
|
||||
current_value.decode("utf-8")
|
||||
if isinstance(current_value, bytes)
|
||||
else str(current_value)
|
||||
)
|
||||
return current_owner
|
||||
|
||||
# Key doesn't exist but we failed to set it - race condition or Redis issue
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ClusterLock.try_acquire failed for key {self.key}: {e}")
|
||||
return None
|
||||
|
||||
def refresh(self) -> bool:
|
||||
"""Refresh lock TTL if we still own it.
|
||||
|
||||
Rate limited to at most once every timeout/10 seconds (minimum 1 second).
|
||||
During rate limiting, still verifies lock existence but skips TTL extension.
|
||||
Setting _last_refresh to 0 bypasses rate limiting for testing.
|
||||
"""
|
||||
# Calculate refresh interval: max(timeout // 10, 1)
|
||||
refresh_interval = max(self.timeout // 10, 1)
|
||||
current_time = time.time()
|
||||
|
||||
# Check if we're within the rate limit period
|
||||
# _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
|
||||
is_rate_limited = (
|
||||
self._last_refresh > 0
|
||||
and (current_time - self._last_refresh) < refresh_interval
|
||||
)
|
||||
|
||||
try:
|
||||
# Always verify lock existence, even during rate limiting
|
||||
current_value = self.redis.get(self.key)
|
||||
if not current_value:
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
stored_owner = (
|
||||
current_value.decode("utf-8")
|
||||
if isinstance(current_value, bytes)
|
||||
else str(current_value)
|
||||
)
|
||||
if stored_owner != self.owner_id:
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
# If rate limited, return True but don't update TTL or timestamp
|
||||
if is_rate_limited:
|
||||
return True
|
||||
|
||||
# Perform actual refresh
|
||||
if self.redis.expire(self.key, self.timeout):
|
||||
self._last_refresh = current_time
|
||||
return True
|
||||
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ClusterLock.refresh failed for key {self.key}: {e}")
|
||||
self._last_refresh = 0
|
||||
return False
|
||||
|
||||
def release(self):
|
||||
"""Release the lock."""
|
||||
if self._last_refresh == 0:
|
||||
return
|
||||
|
||||
try:
|
||||
self.redis.delete(self.key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self._last_refresh = 0.0
|
||||
507
autogpt_platform/backend/backend/executor/cluster_lock_test.py
Normal file
507
autogpt_platform/backend/backend/executor/cluster_lock_test.py
Normal file
@@ -0,0 +1,507 @@
|
||||
"""
|
||||
Integration tests for ClusterLock - Redis-based distributed locking.
|
||||
|
||||
Tests the complete lock lifecycle without mocking Redis to ensure
|
||||
real-world behavior is correct. Covers acquisition, refresh, expiry,
|
||||
contention, and error scenarios.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from threading import Thread
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
|
||||
from .cluster_lock import ClusterLock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def redis_client():
|
||||
"""Get Redis client for testing using same config as backend."""
|
||||
from backend.data.redis_client import HOST, PASSWORD, PORT
|
||||
|
||||
# Use same config as backend but without decode_responses since ClusterLock needs raw bytes
|
||||
client = redis.Redis(
|
||||
host=HOST,
|
||||
port=PORT,
|
||||
password=PASSWORD,
|
||||
decode_responses=False, # ClusterLock needs raw bytes for ownership verification
|
||||
)
|
||||
|
||||
# Clean up any existing test keys
|
||||
try:
|
||||
for key in client.scan_iter(match="test_lock:*"):
|
||||
client.delete(key)
|
||||
except Exception:
|
||||
pass # Ignore cleanup errors
|
||||
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lock_key():
|
||||
"""Generate unique lock key for each test."""
|
||||
return f"test_lock:{uuid.uuid4()}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def owner_id():
|
||||
"""Generate unique owner ID for each test."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class TestClusterLockBasic:
|
||||
"""Basic lock acquisition and release functionality."""
|
||||
|
||||
def test_lock_acquisition_success(self, redis_client, lock_key, owner_id):
|
||||
"""Test basic lock acquisition succeeds."""
|
||||
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
|
||||
|
||||
# Lock should be acquired successfully
|
||||
result = lock.try_acquire()
|
||||
assert result == owner_id # Returns our owner_id when successfully acquired
|
||||
assert lock._last_refresh > 0
|
||||
|
||||
# Lock key should exist in Redis
|
||||
assert redis_client.exists(lock_key) == 1
|
||||
assert redis_client.get(lock_key).decode("utf-8") == owner_id
|
||||
|
||||
def test_lock_acquisition_contention(self, redis_client, lock_key):
|
||||
"""Test second acquisition fails when lock is held."""
|
||||
owner1 = str(uuid.uuid4())
|
||||
owner2 = str(uuid.uuid4())
|
||||
|
||||
lock1 = ClusterLock(redis_client, lock_key, owner1, timeout=60)
|
||||
lock2 = ClusterLock(redis_client, lock_key, owner2, timeout=60)
|
||||
|
||||
# First lock should succeed
|
||||
result1 = lock1.try_acquire()
|
||||
assert result1 == owner1 # Successfully acquired, returns our owner_id
|
||||
|
||||
# Second lock should fail and return the first owner
|
||||
result2 = lock2.try_acquire()
|
||||
assert result2 == owner1 # Returns the current owner (first owner)
|
||||
assert lock2._last_refresh == 0
|
||||
|
||||
def test_lock_release_deletes_redis_key(self, redis_client, lock_key, owner_id):
|
||||
"""Test lock release deletes Redis key and marks locally as released."""
|
||||
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
|
||||
|
||||
lock.try_acquire()
|
||||
assert lock._last_refresh > 0
|
||||
assert redis_client.exists(lock_key) == 1
|
||||
|
||||
# Release should delete Redis key and mark locally as released
|
||||
lock.release()
|
||||
assert lock._last_refresh == 0
|
||||
assert lock._last_refresh == 0.0
|
||||
|
||||
# Redis key should be deleted for immediate release
|
||||
assert redis_client.exists(lock_key) == 0
|
||||
|
||||
# Another lock should be able to acquire immediately
|
||||
new_owner_id = str(uuid.uuid4())
|
||||
new_lock = ClusterLock(redis_client, lock_key, new_owner_id, timeout=60)
|
||||
assert new_lock.try_acquire() == new_owner_id
|
||||
|
||||
|
||||
class TestClusterLockRefresh:
|
||||
"""Lock refresh and TTL management."""
|
||||
|
||||
def test_lock_refresh_success(self, redis_client, lock_key, owner_id):
|
||||
"""Test lock refresh extends TTL."""
|
||||
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
|
||||
|
||||
lock.try_acquire()
|
||||
original_ttl = redis_client.ttl(lock_key)
|
||||
|
||||
# Wait a bit then refresh
|
||||
time.sleep(1)
|
||||
lock._last_refresh = 0 # Force refresh past rate limit
|
||||
assert lock.refresh() is True
|
||||
|
||||
# TTL should be reset to full timeout (allow for small timing differences)
|
||||
new_ttl = redis_client.ttl(lock_key)
|
||||
assert new_ttl >= original_ttl or new_ttl >= 58 # Allow for timing variance
|
||||
|
||||
def test_lock_refresh_rate_limiting(self, redis_client, lock_key, owner_id):
|
||||
"""Test refresh is rate-limited to timeout/10."""
|
||||
lock = ClusterLock(
|
||||
redis_client, lock_key, owner_id, timeout=100
|
||||
) # 100s timeout
|
||||
|
||||
lock.try_acquire()
|
||||
|
||||
# First refresh should work
|
||||
assert lock.refresh() is True
|
||||
first_refresh_time = lock._last_refresh
|
||||
|
||||
# Immediate second refresh should be skipped (rate limited) but verify key exists
|
||||
assert lock.refresh() is True # Returns True but skips actual refresh
|
||||
assert lock._last_refresh == first_refresh_time # Time unchanged
|
||||
|
||||
def test_lock_refresh_verifies_existence_during_rate_limit(
|
||||
self, redis_client, lock_key, owner_id
|
||||
):
|
||||
"""Test refresh verifies lock existence even during rate limiting."""
|
||||
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=100)
|
||||
|
||||
lock.try_acquire()
|
||||
|
||||
# Manually delete the key (simulates expiry or external deletion)
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
# Refresh should detect missing key even during rate limit period
|
||||
assert lock.refresh() is False
|
||||
assert lock._last_refresh == 0
|
||||
|
||||
def test_lock_refresh_ownership_lost(self, redis_client, lock_key, owner_id):
|
||||
"""Test refresh fails when ownership is lost."""
|
||||
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
|
||||
|
||||
lock.try_acquire()
|
||||
|
||||
# Simulate another process taking the lock
|
||||
different_owner = str(uuid.uuid4())
|
||||
redis_client.set(lock_key, different_owner, ex=60)
|
||||
|
||||
# Force refresh past rate limit and verify it fails
|
||||
lock._last_refresh = 0 # Force refresh past rate limit
|
||||
assert lock.refresh() is False
|
||||
assert lock._last_refresh == 0
|
||||
|
||||
def test_lock_refresh_when_not_acquired(self, redis_client, lock_key, owner_id):
|
||||
"""Test refresh fails when lock was never acquired."""
|
||||
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
|
||||
|
||||
# Refresh without acquiring should fail
|
||||
assert lock.refresh() is False
|
||||
|
||||
|
||||
class TestClusterLockExpiry:
|
||||
"""Lock expiry and timeout behavior."""
|
||||
|
||||
def test_lock_natural_expiry(self, redis_client, lock_key, owner_id):
|
||||
"""Test lock expires naturally via Redis TTL."""
|
||||
lock = ClusterLock(
|
||||
redis_client, lock_key, owner_id, timeout=2
|
||||
) # 2 second timeout
|
||||
|
||||
lock.try_acquire()
|
||||
assert redis_client.exists(lock_key) == 1
|
||||
|
||||
# Wait for expiry
|
||||
time.sleep(3)
|
||||
assert redis_client.exists(lock_key) == 0
|
||||
|
||||
# New lock with same key should succeed
|
||||
new_lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
|
||||
assert new_lock.try_acquire() == owner_id
|
||||
|
||||
def test_lock_refresh_prevents_expiry(self, redis_client, lock_key, owner_id):
|
||||
"""Test refreshing prevents lock from expiring."""
|
||||
lock = ClusterLock(
|
||||
redis_client, lock_key, owner_id, timeout=3
|
||||
) # 3 second timeout
|
||||
|
||||
lock.try_acquire()
|
||||
|
||||
# Wait and refresh before expiry
|
||||
time.sleep(1)
|
||||
lock._last_refresh = 0 # Force refresh past rate limit
|
||||
assert lock.refresh() is True
|
||||
|
||||
# Wait beyond original timeout
|
||||
time.sleep(2.5)
|
||||
assert redis_client.exists(lock_key) == 1 # Should still exist
|
||||
|
||||
|
||||
class TestClusterLockConcurrency:
|
||||
"""Concurrent access patterns."""
|
||||
|
||||
def test_multiple_threads_contention(self, redis_client, lock_key):
|
||||
"""Test multiple threads competing for same lock."""
|
||||
num_threads = 5
|
||||
successful_acquisitions = []
|
||||
|
||||
def try_acquire_lock(thread_id):
|
||||
owner_id = f"thread_{thread_id}"
|
||||
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
|
||||
if lock.try_acquire() == owner_id:
|
||||
successful_acquisitions.append(thread_id)
|
||||
time.sleep(0.1) # Hold lock briefly
|
||||
lock.release()
|
||||
|
||||
threads = []
|
||||
for i in range(num_threads):
|
||||
thread = Thread(target=try_acquire_lock, args=(i,))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# Only one thread should have acquired the lock
|
||||
assert len(successful_acquisitions) == 1
|
||||
|
||||
def test_sequential_lock_reuse(self, redis_client, lock_key):
|
||||
"""Test lock can be reused after natural expiry."""
|
||||
owners = [str(uuid.uuid4()) for _ in range(3)]
|
||||
|
||||
for i, owner_id in enumerate(owners):
|
||||
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=1) # 1 second
|
||||
|
||||
assert lock.try_acquire() == owner_id
|
||||
time.sleep(1.5) # Wait for expiry
|
||||
|
||||
# Verify lock expired
|
||||
assert redis_client.exists(lock_key) == 0
|
||||
|
||||
def test_refresh_during_concurrent_access(self, redis_client, lock_key):
|
||||
"""Test lock refresh works correctly during concurrent access attempts."""
|
||||
owner1 = str(uuid.uuid4())
|
||||
owner2 = str(uuid.uuid4())
|
||||
|
||||
lock1 = ClusterLock(redis_client, lock_key, owner1, timeout=5)
|
||||
lock2 = ClusterLock(redis_client, lock_key, owner2, timeout=5)
|
||||
|
||||
# Thread 1 holds lock and refreshes
|
||||
assert lock1.try_acquire() == owner1
|
||||
|
||||
def refresh_continuously():
|
||||
for _ in range(10):
|
||||
lock1._last_refresh = 0 # Force refresh
|
||||
lock1.refresh()
|
||||
time.sleep(0.1)
|
||||
|
||||
def try_acquire_continuously():
|
||||
attempts = 0
|
||||
while attempts < 20:
|
||||
if lock2.try_acquire() == owner2:
|
||||
return True
|
||||
time.sleep(0.1)
|
||||
attempts += 1
|
||||
return False
|
||||
|
||||
refresh_thread = Thread(target=refresh_continuously)
|
||||
acquire_thread = Thread(target=try_acquire_continuously)
|
||||
|
||||
refresh_thread.start()
|
||||
acquire_thread.start()
|
||||
|
||||
refresh_thread.join()
|
||||
acquire_thread.join()
|
||||
|
||||
# Lock1 should still own the lock due to refreshes
|
||||
assert lock1._last_refresh > 0
|
||||
assert lock2._last_refresh == 0
|
||||
|
||||
|
||||
class TestClusterLockErrorHandling:
|
||||
"""Error handling and edge cases."""
|
||||
|
||||
def test_redis_connection_failure_on_acquire(self, lock_key, owner_id):
|
||||
"""Test graceful handling when Redis is unavailable during acquisition."""
|
||||
# Use invalid Redis connection
|
||||
bad_redis = redis.Redis(
|
||||
host="invalid_host", port=1234, socket_connect_timeout=1
|
||||
)
|
||||
lock = ClusterLock(bad_redis, lock_key, owner_id, timeout=60)
|
||||
|
||||
# Should return None for Redis connection failures
|
||||
result = lock.try_acquire()
|
||||
assert result is None # Returns None when Redis fails
|
||||
assert lock._last_refresh == 0
|
||||
|
||||
def test_redis_connection_failure_on_refresh(
|
||||
self, redis_client, lock_key, owner_id
|
||||
):
|
||||
"""Test graceful handling when Redis fails during refresh."""
|
||||
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
|
||||
|
||||
# Acquire normally
|
||||
assert lock.try_acquire() == owner_id
|
||||
|
||||
# Replace Redis client with failing one
|
||||
lock.redis = redis.Redis(
|
||||
host="invalid_host", port=1234, socket_connect_timeout=1
|
||||
)
|
||||
|
||||
# Refresh should fail gracefully
|
||||
lock._last_refresh = 0 # Force refresh
|
||||
assert lock.refresh() is False
|
||||
assert lock._last_refresh == 0
|
||||
|
||||
def test_invalid_lock_parameters(self, redis_client):
|
||||
"""Test validation of lock parameters."""
|
||||
owner_id = str(uuid.uuid4())
|
||||
|
||||
# All parameters are now simple - no validation needed
|
||||
# Just test basic construction works
|
||||
lock = ClusterLock(redis_client, "test_key", owner_id, timeout=60)
|
||||
assert lock.key == "test_key"
|
||||
assert lock.owner_id == owner_id
|
||||
assert lock.timeout == 60
|
||||
|
||||
def test_refresh_after_redis_key_deleted(self, redis_client, lock_key, owner_id):
|
||||
"""Test refresh behavior when Redis key is manually deleted."""
|
||||
lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
|
||||
|
||||
lock.try_acquire()
|
||||
|
||||
# Manually delete the key (simulates external deletion)
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
# Refresh should fail and mark as not acquired
|
||||
lock._last_refresh = 0 # Force refresh
|
||||
assert lock.refresh() is False
|
||||
assert lock._last_refresh == 0
|
||||
|
||||
|
||||
class TestClusterLockDynamicRefreshInterval:
|
||||
"""Dynamic refresh interval based on timeout."""
|
||||
|
||||
def test_refresh_interval_calculation(self, redis_client, lock_key, owner_id):
|
||||
"""Test refresh interval is calculated as max(timeout/10, 1)."""
|
||||
test_cases = [
|
||||
(5, 1), # 5/10 = 0, but minimum is 1
|
||||
(10, 1), # 10/10 = 1
|
||||
(30, 3), # 30/10 = 3
|
||||
(100, 10), # 100/10 = 10
|
||||
(200, 20), # 200/10 = 20
|
||||
(1000, 100), # 1000/10 = 100
|
||||
]
|
||||
|
||||
for timeout, expected_interval in test_cases:
|
||||
lock = ClusterLock(
|
||||
redis_client, f"{lock_key}_{timeout}", owner_id, timeout=timeout
|
||||
)
|
||||
lock.try_acquire()
|
||||
|
||||
# Calculate expected interval using same logic as implementation
|
||||
refresh_interval = max(timeout // 10, 1)
|
||||
assert refresh_interval == expected_interval
|
||||
|
||||
# Test rate limiting works with calculated interval
|
||||
assert lock.refresh() is True
|
||||
first_refresh_time = lock._last_refresh
|
||||
|
||||
# Sleep less than interval - should be rate limited
|
||||
time.sleep(0.1)
|
||||
assert lock.refresh() is True
|
||||
assert lock._last_refresh == first_refresh_time # No actual refresh
|
||||
|
||||
|
||||
class TestClusterLockRealWorldScenarios:
|
||||
"""Real-world usage patterns."""
|
||||
|
||||
def test_execution_coordination_simulation(self, redis_client):
|
||||
"""Simulate graph execution coordination across multiple pods."""
|
||||
graph_exec_id = str(uuid.uuid4())
|
||||
lock_key = f"execution:{graph_exec_id}"
|
||||
|
||||
# Simulate 3 pods trying to execute same graph
|
||||
pods = [f"pod_{i}" for i in range(3)]
|
||||
execution_results = {}
|
||||
|
||||
def execute_graph(pod_id):
|
||||
"""Simulate graph execution with cluster lock."""
|
||||
lock = ClusterLock(redis_client, lock_key, pod_id, timeout=300)
|
||||
|
||||
if lock.try_acquire() == pod_id:
|
||||
# Simulate execution work
|
||||
execution_results[pod_id] = "executed"
|
||||
time.sleep(0.1)
|
||||
lock.release()
|
||||
else:
|
||||
execution_results[pod_id] = "rejected"
|
||||
|
||||
threads = []
|
||||
for pod_id in pods:
|
||||
thread = Thread(target=execute_graph, args=(pod_id,))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# Only one pod should have executed
|
||||
executed_count = sum(
|
||||
1 for result in execution_results.values() if result == "executed"
|
||||
)
|
||||
rejected_count = sum(
|
||||
1 for result in execution_results.values() if result == "rejected"
|
||||
)
|
||||
|
||||
assert executed_count == 1
|
||||
assert rejected_count == 2
|
||||
|
||||
def test_long_running_execution_with_refresh(
|
||||
self, redis_client, lock_key, owner_id
|
||||
):
|
||||
"""Test lock maintains ownership during long execution with periodic refresh."""
|
||||
lock = ClusterLock(
|
||||
redis_client, lock_key, owner_id, timeout=30
|
||||
) # 30 second timeout, refresh interval = max(30//10, 1) = 3 seconds
|
||||
|
||||
def long_execution_with_refresh():
|
||||
"""Simulate long-running execution with periodic refresh."""
|
||||
assert lock.try_acquire() == owner_id
|
||||
|
||||
# Simulate 10 seconds of work with refreshes every 2 seconds
|
||||
# This respects rate limiting - actual refreshes will happen at 0s, 3s, 6s, 9s
|
||||
try:
|
||||
for i in range(5): # 5 iterations * 2 seconds = 10 seconds total
|
||||
time.sleep(2)
|
||||
refresh_success = lock.refresh()
|
||||
assert refresh_success is True, f"Refresh failed at iteration {i}"
|
||||
return "completed"
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
# Should complete successfully without losing lock
|
||||
result = long_execution_with_refresh()
|
||||
assert result == "completed"
|
||||
|
||||
def test_graceful_degradation_pattern(self, redis_client, lock_key):
|
||||
"""Test graceful degradation when Redis becomes unavailable."""
|
||||
owner_id = str(uuid.uuid4())
|
||||
lock = ClusterLock(
|
||||
redis_client, lock_key, owner_id, timeout=3
|
||||
) # Use shorter timeout
|
||||
|
||||
# Normal operation
|
||||
assert lock.try_acquire() == owner_id
|
||||
lock._last_refresh = 0 # Force refresh past rate limit
|
||||
assert lock.refresh() is True
|
||||
|
||||
# Simulate Redis becoming unavailable
|
||||
original_redis = lock.redis
|
||||
lock.redis = redis.Redis(
|
||||
host="invalid_host",
|
||||
port=1234,
|
||||
socket_connect_timeout=1,
|
||||
decode_responses=False,
|
||||
)
|
||||
|
||||
# Should degrade gracefully
|
||||
lock._last_refresh = 0 # Force refresh past rate limit
|
||||
assert lock.refresh() is False
|
||||
assert lock._last_refresh == 0
|
||||
|
||||
# Restore Redis and verify can acquire again
|
||||
lock.redis = original_redis
|
||||
# Wait for original lock to expire (use longer wait for 3s timeout)
|
||||
time.sleep(4)
|
||||
|
||||
new_lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
|
||||
assert new_lock.try_acquire() == owner_id
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run specific test for quick validation
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -85,6 +85,16 @@ class DatabaseManager(AppService):
|
||||
async def health_check(self) -> str:
|
||||
if not db.is_connected():
|
||||
raise UnhealthyServiceError("Database is not connected")
|
||||
|
||||
try:
|
||||
# Test actual database connectivity by executing a simple query
|
||||
# This will fail if Prisma query engine is not responding
|
||||
result = await db.query_raw_with_schema("SELECT 1 as health_check")
|
||||
if not result or result[0].get("health_check") != 1:
|
||||
raise UnhealthyServiceError("Database query test failed")
|
||||
except Exception as e:
|
||||
raise UnhealthyServiceError(f"Database health check failed: {e}")
|
||||
|
||||
return await super().health_check()
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from contextlib import asynccontextmanager
|
||||
@@ -10,31 +11,11 @@ from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
|
||||
|
||||
from pika.adapters.blocking_connection import BlockingChannel
|
||||
from pika.spec import Basic, BasicProperties
|
||||
from redis.asyncio.lock import Lock as RedisLock
|
||||
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
from backend.data.model import GraphExecutionStats, NodeExecutionStats
|
||||
from backend.data.notifications import (
|
||||
AgentRunData,
|
||||
LowBalanceData,
|
||||
NotificationEventModel,
|
||||
NotificationType,
|
||||
ZeroBalanceData,
|
||||
)
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.activity_status_generator import (
|
||||
generate_activity_status_for_execution,
|
||||
)
|
||||
from backend.executor.utils import LogMetadata
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.util.exceptions import InsufficientBalanceError, ModerationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerClient, DatabaseManagerAsyncClient
|
||||
|
||||
from prometheus_client import Gauge, start_http_server
|
||||
from redis.asyncio.lock import Lock as AsyncRedisLock
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.block import (
|
||||
BlockInput,
|
||||
@@ -55,12 +36,25 @@ from backend.data.execution import (
|
||||
UserContext,
|
||||
)
|
||||
from backend.data.graph import Link, Node
|
||||
from backend.data.model import GraphExecutionStats, NodeExecutionStats
|
||||
from backend.data.notifications import (
|
||||
AgentRunData,
|
||||
LowBalanceData,
|
||||
NotificationEventModel,
|
||||
NotificationType,
|
||||
ZeroBalanceData,
|
||||
)
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.activity_status_generator import (
|
||||
generate_activity_status_for_execution,
|
||||
)
|
||||
from backend.executor.utils import (
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
||||
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
|
||||
GRAPH_EXECUTION_QUEUE_NAME,
|
||||
CancelExecutionEvent,
|
||||
ExecutionOutputEntry,
|
||||
LogMetadata,
|
||||
NodeExecutionProgress,
|
||||
block_usage_cost,
|
||||
create_execution_queue_config,
|
||||
@@ -69,6 +63,7 @@ from backend.executor.utils import (
|
||||
validate_exec,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.server.v2.AutoMod.manager import automod_manager
|
||||
from backend.util import json
|
||||
from backend.util.clients import (
|
||||
@@ -84,6 +79,7 @@ from backend.util.decorator import (
|
||||
error_logged,
|
||||
time_measured,
|
||||
)
|
||||
from backend.util.exceptions import InsufficientBalanceError, ModerationError
|
||||
from backend.util.file import clean_exec_files
|
||||
from backend.util.logging import TruncatedLogger, configure_logging
|
||||
from backend.util.metrics import DiscordChannel
|
||||
@@ -91,6 +87,12 @@ from backend.util.process import AppProcess, set_service_name
|
||||
from backend.util.retry import continuous_retry, func_retry
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .cluster_lock import ClusterLock
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
logger = TruncatedLogger(_logger, prefix="[GraphExecutor]")
|
||||
settings = Settings()
|
||||
@@ -106,6 +108,7 @@ utilization_gauge = Gauge(
|
||||
"Ratio of active graph runs to max graph workers",
|
||||
)
|
||||
|
||||
|
||||
# Thread-local storage for ExecutionProcessor instances
|
||||
_tls = threading.local()
|
||||
|
||||
@@ -117,10 +120,14 @@ def init_worker():
|
||||
|
||||
|
||||
def execute_graph(
|
||||
graph_exec_entry: "GraphExecutionEntry", cancel_event: threading.Event
|
||||
graph_exec_entry: "GraphExecutionEntry",
|
||||
cancel_event: threading.Event,
|
||||
cluster_lock: ClusterLock,
|
||||
):
|
||||
"""Execute graph using thread-local ExecutionProcessor instance"""
|
||||
return _tls.processor.on_graph_execution(graph_exec_entry, cancel_event)
|
||||
return _tls.processor.on_graph_execution(
|
||||
graph_exec_entry, cancel_event, cluster_lock
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -429,7 +436,7 @@ class ExecutionProcessor:
|
||||
graph_id=node_exec.graph_id,
|
||||
node_eid=node_exec.node_exec_id,
|
||||
node_id=node_exec.node_id,
|
||||
block_name="-",
|
||||
block_name=b.name if (b := get_block(node_exec.block_id)) else "-",
|
||||
)
|
||||
db_client = get_db_async_client()
|
||||
node = await db_client.get_node(node_exec.node_id)
|
||||
@@ -583,6 +590,7 @@ class ExecutionProcessor:
|
||||
self,
|
||||
graph_exec: GraphExecutionEntry,
|
||||
cancel: threading.Event,
|
||||
cluster_lock: ClusterLock,
|
||||
):
|
||||
log_metadata = LogMetadata(
|
||||
logger=_logger,
|
||||
@@ -641,6 +649,7 @@ class ExecutionProcessor:
|
||||
cancel=cancel,
|
||||
log_metadata=log_metadata,
|
||||
execution_stats=exec_stats,
|
||||
cluster_lock=cluster_lock,
|
||||
)
|
||||
exec_stats.walltime += timing_info.wall_time
|
||||
exec_stats.cputime += timing_info.cpu_time
|
||||
@@ -742,6 +751,7 @@ class ExecutionProcessor:
|
||||
cancel: threading.Event,
|
||||
log_metadata: LogMetadata,
|
||||
execution_stats: GraphExecutionStats,
|
||||
cluster_lock: ClusterLock,
|
||||
) -> ExecutionStatus:
|
||||
"""
|
||||
Returns:
|
||||
@@ -927,7 +937,7 @@ class ExecutionProcessor:
|
||||
and execution_queue.empty()
|
||||
and (running_node_execution or running_node_evaluation)
|
||||
):
|
||||
# There is nothing to execute, and no output to process, let's relax for a while.
|
||||
cluster_lock.refresh()
|
||||
time.sleep(0.1)
|
||||
|
||||
# loop done --------------------------------------------------
|
||||
@@ -1219,6 +1229,7 @@ class ExecutionManager(AppProcess):
|
||||
super().__init__()
|
||||
self.pool_size = settings.config.num_graph_workers
|
||||
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
|
||||
self.executor_id = str(uuid.uuid4())
|
||||
|
||||
self._executor = None
|
||||
self._stop_consuming = None
|
||||
@@ -1228,6 +1239,8 @@ class ExecutionManager(AppProcess):
|
||||
self._run_thread = None
|
||||
self._run_client = None
|
||||
|
||||
self._execution_locks = {}
|
||||
|
||||
@property
|
||||
def cancel_thread(self) -> threading.Thread:
|
||||
if self._cancel_thread is None:
|
||||
@@ -1435,17 +1448,46 @@ class ExecutionManager(AppProcess):
|
||||
logger.info(
|
||||
f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}"
|
||||
)
|
||||
|
||||
# Check for local duplicate execution first
|
||||
if graph_exec_id in self.active_graph_runs:
|
||||
# TODO: Make this check cluster-wide, prevent duplicate runs across executor pods.
|
||||
logger.error(
|
||||
f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
|
||||
logger.warning(
|
||||
f"[{self.service_name}] Graph {graph_exec_id} already running locally; rejecting duplicate."
|
||||
)
|
||||
_ack_message(reject=True, requeue=False)
|
||||
_ack_message(reject=True, requeue=True)
|
||||
return
|
||||
|
||||
# Try to acquire cluster-wide execution lock
|
||||
cluster_lock = ClusterLock(
|
||||
redis=redis.get_redis(),
|
||||
key=f"exec_lock:{graph_exec_id}",
|
||||
owner_id=self.executor_id,
|
||||
timeout=settings.config.cluster_lock_timeout,
|
||||
)
|
||||
current_owner = cluster_lock.try_acquire()
|
||||
if current_owner != self.executor_id:
|
||||
# Either someone else has it or Redis is unavailable
|
||||
if current_owner is not None:
|
||||
logger.warning(
|
||||
f"[{self.service_name}] Graph {graph_exec_id} already running on pod {current_owner}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[{self.service_name}] Could not acquire lock for {graph_exec_id} - Redis unavailable"
|
||||
)
|
||||
_ack_message(reject=True, requeue=True)
|
||||
return
|
||||
self._execution_locks[graph_exec_id] = cluster_lock
|
||||
|
||||
logger.info(
|
||||
f"[{self.service_name}] Acquired cluster lock for {graph_exec_id} with executor {self.executor_id}"
|
||||
)
|
||||
|
||||
cancel_event = threading.Event()
|
||||
|
||||
future = self.executor.submit(execute_graph, graph_exec_entry, cancel_event)
|
||||
future = self.executor.submit(
|
||||
execute_graph, graph_exec_entry, cancel_event, cluster_lock
|
||||
)
|
||||
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
|
||||
self._update_prompt_metrics()
|
||||
|
||||
@@ -1464,6 +1506,10 @@ class ExecutionManager(AppProcess):
|
||||
f"[{self.service_name}] Error in run completion callback: {e}"
|
||||
)
|
||||
finally:
|
||||
# Release the cluster-wide execution lock
|
||||
if graph_exec_id in self._execution_locks:
|
||||
self._execution_locks[graph_exec_id].release()
|
||||
del self._execution_locks[graph_exec_id]
|
||||
self._cleanup_completed_runs()
|
||||
|
||||
future.add_done_callback(_on_run_done)
|
||||
@@ -1546,6 +1592,10 @@ class ExecutionManager(AppProcess):
|
||||
f"{prefix} ⏳ Still waiting for {len(self.active_graph_runs)} executions: {ids}"
|
||||
)
|
||||
|
||||
for graph_exec_id in self.active_graph_runs:
|
||||
if lock := self._execution_locks.get(graph_exec_id):
|
||||
lock.refresh()
|
||||
|
||||
time.sleep(wait_interval)
|
||||
waited += wait_interval
|
||||
|
||||
@@ -1563,6 +1613,15 @@ class ExecutionManager(AppProcess):
|
||||
except Exception as e:
|
||||
logger.error(f"{prefix} ⚠️ Error during executor shutdown: {type(e)} {e}")
|
||||
|
||||
# Release remaining execution locks
|
||||
try:
|
||||
for lock in self._execution_locks.values():
|
||||
lock.release()
|
||||
self._execution_locks.clear()
|
||||
logger.info(f"{prefix} ✅ Released execution locks")
|
||||
except Exception as e:
|
||||
logger.warning(f"{prefix} ⚠️ Failed to release all locks: {e}")
|
||||
|
||||
# Disconnect the run execution consumer
|
||||
self._stop_message_consumers(
|
||||
self.run_thread,
|
||||
@@ -1668,15 +1727,18 @@ def update_graph_execution_state(
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def synchronized(key: str, timeout: int = 60):
|
||||
async def synchronized(key: str, timeout: int = settings.config.cluster_lock_timeout):
|
||||
r = await redis.get_redis_async()
|
||||
lock: RedisLock = r.lock(f"lock:{key}", timeout=timeout)
|
||||
lock: AsyncRedisLock = r.lock(f"lock:{key}", timeout=timeout)
|
||||
try:
|
||||
await lock.acquire()
|
||||
yield
|
||||
finally:
|
||||
if await lock.locked() and await lock.owned():
|
||||
await lock.release()
|
||||
try:
|
||||
await lock.release()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to release lock for key {key}: {e}")
|
||||
|
||||
|
||||
def increment_execution_count(user_id: str) -> int:
|
||||
|
||||
@@ -151,7 +151,10 @@ class IntegrationCredentialsManager:
|
||||
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
|
||||
await self.store.update_creds(user_id, fresh_credentials)
|
||||
if _lock and (await _lock.locked()) and (await _lock.owned()):
|
||||
await _lock.release()
|
||||
try:
|
||||
await _lock.release()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to release OAuth refresh lock: {e}")
|
||||
|
||||
credentials = fresh_credentials
|
||||
return credentials
|
||||
@@ -184,7 +187,10 @@ class IntegrationCredentialsManager:
|
||||
yield
|
||||
finally:
|
||||
if (await lock.locked()) and (await lock.owned()):
|
||||
await lock.release()
|
||||
try:
|
||||
await lock.release()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to release credentials lock: {e}")
|
||||
|
||||
async def release_all_locks(self):
|
||||
"""Call this on process termination to ensure all locks are released"""
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import functools
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from autogpt_libs.utils.cache import cached
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..providers import ProviderName
|
||||
from ._base import BaseWebhooksManager
|
||||
|
||||
|
||||
# --8<-- [start:load_webhook_managers]
|
||||
@functools.cache
|
||||
@cached()
|
||||
def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]:
|
||||
webhook_managers = {}
|
||||
|
||||
|
||||
@@ -168,38 +168,45 @@ async def migrate_legacy_triggered_graphs():
|
||||
n_migrated_webhooks = 0
|
||||
|
||||
for graph in triggered_graphs:
|
||||
if not ((trigger_node := graph.webhook_input_node) and trigger_node.webhook_id):
|
||||
try:
|
||||
if not (
|
||||
(trigger_node := graph.webhook_input_node) and trigger_node.webhook_id
|
||||
):
|
||||
continue
|
||||
|
||||
# Use trigger node's inputs for the preset
|
||||
preset_credentials = {
|
||||
field_name: creds_meta
|
||||
for field_name, creds_meta in trigger_node.input_default.items()
|
||||
if is_credentials_field_name(field_name)
|
||||
}
|
||||
preset_inputs = {
|
||||
field_name: value
|
||||
for field_name, value in trigger_node.input_default.items()
|
||||
if not is_credentials_field_name(field_name)
|
||||
}
|
||||
|
||||
# Create a triggered preset for the graph
|
||||
await create_preset(
|
||||
graph.user_id,
|
||||
LibraryAgentPresetCreatable(
|
||||
graph_id=graph.id,
|
||||
graph_version=graph.version,
|
||||
inputs=preset_inputs,
|
||||
credentials=preset_credentials,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
webhook_id=trigger_node.webhook_id,
|
||||
is_active=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Detach webhook from the graph node
|
||||
await set_node_webhook(trigger_node.id, None)
|
||||
|
||||
n_migrated_webhooks += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to migrate graph #{graph.id} trigger to preset: {e}")
|
||||
continue
|
||||
|
||||
# Use trigger node's inputs for the preset
|
||||
preset_credentials = {
|
||||
field_name: creds_meta
|
||||
for field_name, creds_meta in trigger_node.input_default.items()
|
||||
if is_credentials_field_name(field_name)
|
||||
}
|
||||
preset_inputs = {
|
||||
field_name: value
|
||||
for field_name, value in trigger_node.input_default.items()
|
||||
if not is_credentials_field_name(field_name)
|
||||
}
|
||||
|
||||
# Create a triggered preset for the graph
|
||||
await create_preset(
|
||||
graph.user_id,
|
||||
LibraryAgentPresetCreatable(
|
||||
graph_id=graph.id,
|
||||
graph_version=graph.version,
|
||||
inputs=preset_inputs,
|
||||
credentials=preset_credentials,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
webhook_id=trigger_node.webhook_id,
|
||||
is_active=True,
|
||||
),
|
||||
)
|
||||
# Detach webhook from the graph node
|
||||
await set_node_webhook(trigger_node.id, None)
|
||||
|
||||
n_migrated_webhooks += 1
|
||||
|
||||
logger.info(f"Migrated {n_migrated_webhooks} node triggers to triggered presets")
|
||||
|
||||
@@ -49,7 +49,7 @@ class GraphExecutionResult(TypedDict):
|
||||
tags=["blocks"],
|
||||
dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
|
||||
)
|
||||
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
blocks = [block() for block in backend.data.block.get_blocks().values()]
|
||||
return [b.to_dict() for b in blocks if not b.disabled]
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ from backend.data.model import (
|
||||
OAuth2Credentials,
|
||||
UserIntegrations,
|
||||
)
|
||||
from backend.data.onboarding import complete_webhook_trigger_step
|
||||
from backend.data.user import get_user_integrations
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||
@@ -367,6 +368,8 @@ async def webhook_ingress_generic(
|
||||
return
|
||||
|
||||
executions: list[Awaitable] = []
|
||||
await complete_webhook_trigger_step(user_id)
|
||||
|
||||
for node in webhook.triggered_nodes:
|
||||
logger.debug(f"Webhook-attached node: {node}")
|
||||
if not node.is_triggered_by_event_type(event_type):
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import re
|
||||
from typing import Set
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
class SecurityHeadersMiddleware:
|
||||
"""
|
||||
Middleware to add security headers to responses, with cache control
|
||||
disabled by default for all endpoints except those explicitly allowed.
|
||||
@@ -25,6 +23,8 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"/api/health",
|
||||
"/api/v1/health",
|
||||
"/api/status",
|
||||
"/api/blocks",
|
||||
"/api/v1/blocks",
|
||||
# Public store/marketplace pages (read-only)
|
||||
"/api/store/agents",
|
||||
"/api/v1/store/agents",
|
||||
@@ -49,7 +49,7 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
}
|
||||
|
||||
def __init__(self, app: ASGIApp):
|
||||
super().__init__(app)
|
||||
self.app = app
|
||||
# Compile regex patterns for wildcard matching
|
||||
self.cacheable_patterns = [
|
||||
re.compile(pattern.replace("*", "[^/]+"))
|
||||
@@ -72,26 +72,42 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
return False
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
response: Response = await call_next(request)
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
"""Pure ASGI middleware implementation for better performance than BaseHTTPMiddleware."""
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
# Add general security headers
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
# Extract path from scope
|
||||
path = scope["path"]
|
||||
|
||||
# Add noindex header for shared execution pages
|
||||
if "/public/shared" in request.url.path:
|
||||
response.headers["X-Robots-Tag"] = "noindex, nofollow"
|
||||
async def send_wrapper(message: Message) -> None:
|
||||
if message["type"] == "http.response.start":
|
||||
# Add security headers to the response
|
||||
headers = dict(message.get("headers", []))
|
||||
|
||||
# Default: Disable caching for all endpoints
|
||||
# Only allow caching for explicitly permitted paths
|
||||
if not self.is_cacheable_path(request.url.path):
|
||||
response.headers["Cache-Control"] = (
|
||||
"no-store, no-cache, must-revalidate, private"
|
||||
)
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
response.headers["Expires"] = "0"
|
||||
# Add general security headers (HTTP spec requires proper capitalization)
|
||||
headers[b"X-Content-Type-Options"] = b"nosniff"
|
||||
headers[b"X-Frame-Options"] = b"DENY"
|
||||
headers[b"X-XSS-Protection"] = b"1; mode=block"
|
||||
headers[b"Referrer-Policy"] = b"strict-origin-when-cross-origin"
|
||||
|
||||
return response
|
||||
# Add noindex header for shared execution pages
|
||||
if "/public/shared" in path:
|
||||
headers[b"X-Robots-Tag"] = b"noindex, nofollow"
|
||||
|
||||
# Default: Disable caching for all endpoints
|
||||
# Only allow caching for explicitly permitted paths
|
||||
if not self.is_cacheable_path(path):
|
||||
headers[b"Cache-Control"] = (
|
||||
b"no-store, no-cache, must-revalidate, private"
|
||||
)
|
||||
headers[b"Pragma"] = b"no-cache"
|
||||
headers[b"Expires"] = b"0"
|
||||
|
||||
# Convert headers back to list format
|
||||
message["headers"] = list(headers.items())
|
||||
|
||||
await send(message)
|
||||
|
||||
await self.app(scope, receive, send_wrapper)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import platform
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -11,6 +12,7 @@ import uvicorn
|
||||
from autogpt_libs.auth import add_auth_responses_to_openapi
|
||||
from autogpt_libs.auth import verify_settings as verify_auth_settings
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.routing import APIRoute
|
||||
from prisma.errors import PrismaError
|
||||
|
||||
@@ -70,6 +72,26 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
await backend.data.db.connect()
|
||||
|
||||
# Configure thread pool for FastAPI sync operation performance
|
||||
# CRITICAL: FastAPI automatically runs ALL sync functions in this thread pool:
|
||||
# - Any endpoint defined with 'def' (not async def)
|
||||
# - Any dependency function defined with 'def' (not async def)
|
||||
# - Manual run_in_threadpool() calls (like JWT decoding)
|
||||
# Default pool size is only 40 threads, causing bottlenecks under high concurrency
|
||||
config = backend.util.settings.Config()
|
||||
try:
|
||||
import anyio.to_thread
|
||||
|
||||
anyio.to_thread.current_default_thread_limiter().total_tokens = (
|
||||
config.fastapi_thread_pool_size
|
||||
)
|
||||
logger.info(
|
||||
f"Thread pool size set to {config.fastapi_thread_pool_size} for sync endpoint/dependency performance"
|
||||
)
|
||||
except (ImportError, AttributeError) as e:
|
||||
logger.warning(f"Could not configure thread pool size: {e}")
|
||||
# Continue without thread pool configuration
|
||||
|
||||
# Ensure SDK auto-registration is patched before initializing blocks
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
|
||||
@@ -140,6 +162,9 @@ app = fastapi.FastAPI(
|
||||
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
# Add GZip compression middleware for large responses (like /api/blocks)
|
||||
app.add_middleware(GZipMiddleware, minimum_size=50_000) # 50KB threshold
|
||||
|
||||
# Add 401 responses to authenticated endpoints in OpenAPI spec
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
@@ -273,12 +298,28 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
allow_methods=["*"], # Allows all methods
|
||||
allow_headers=["*"], # Allows all headers
|
||||
)
|
||||
uvicorn.run(
|
||||
server_app,
|
||||
host=backend.util.settings.Config().agent_api_host,
|
||||
port=backend.util.settings.Config().agent_api_port,
|
||||
log_config=None,
|
||||
)
|
||||
config = backend.util.settings.Config()
|
||||
|
||||
# Configure uvicorn with performance optimizations from Kludex FastAPI tips
|
||||
uvicorn_config = {
|
||||
"app": server_app,
|
||||
"host": config.agent_api_host,
|
||||
"port": config.agent_api_port,
|
||||
"log_config": None,
|
||||
# Use httptools for HTTP parsing (if available)
|
||||
"http": "httptools",
|
||||
# Only use uvloop on Unix-like systems (not supported on Windows)
|
||||
"loop": "uvloop" if platform.system() != "Windows" else "auto",
|
||||
}
|
||||
|
||||
# Only add debug in local environment (not supported in all uvicorn versions)
|
||||
if config.app_env == backend.util.settings.AppEnvironment.LOCAL:
|
||||
import os
|
||||
|
||||
# Enable asyncio debug mode via environment variable
|
||||
os.environ["PYTHONASYNCIODEBUG"] = "1"
|
||||
|
||||
uvicorn.run(**uvicorn_config)
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
|
||||
@@ -11,6 +11,7 @@ import pydantic
|
||||
import stripe
|
||||
from autogpt_libs.auth import get_user_id, requires_user
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
@@ -23,6 +24,8 @@ from fastapi import (
|
||||
Security,
|
||||
UploadFile,
|
||||
)
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from pydantic import BaseModel
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
from typing_extensions import Optional, TypedDict
|
||||
|
||||
@@ -38,10 +41,10 @@ from backend.data.credit import (
|
||||
RefundRequest,
|
||||
TransactionHistory,
|
||||
get_auto_top_up,
|
||||
get_block_costs,
|
||||
get_user_credit_model,
|
||||
set_auto_top_up,
|
||||
)
|
||||
from backend.data.execution import UserContext
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.data.onboarding import (
|
||||
@@ -84,6 +87,7 @@ from backend.server.model import (
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.cloud_storage import get_cloud_storage_handler
|
||||
from backend.util.exceptions import GraphValidationError, NotFoundError
|
||||
from backend.util.json import dumps
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.timezone_utils import (
|
||||
convert_utc_time_to_user_timezone,
|
||||
@@ -262,18 +266,69 @@ async def is_onboarding_enabled():
|
||||
########################################################
|
||||
|
||||
|
||||
def _compute_blocks_sync() -> str:
|
||||
"""
|
||||
Synchronous function to compute blocks data.
|
||||
This does the heavy lifting: instantiate 226+ blocks, compute costs, serialize.
|
||||
"""
|
||||
from backend.data.credit import get_block_cost
|
||||
|
||||
block_classes = get_blocks()
|
||||
result = []
|
||||
|
||||
for block_class in block_classes.values():
|
||||
block_instance = block_class()
|
||||
if not block_instance.disabled:
|
||||
costs = get_block_cost(block_instance)
|
||||
# Convert BlockCost BaseModel objects to dictionaries for JSON serialization
|
||||
costs_dict = [
|
||||
cost.model_dump() if isinstance(cost, BaseModel) else cost
|
||||
for cost in costs
|
||||
]
|
||||
result.append({**block_instance.to_dict(), "costs": costs_dict})
|
||||
|
||||
# Use our JSON utility which properly handles complex types through to_dict conversion
|
||||
return dumps(result)
|
||||
|
||||
|
||||
@cached()
|
||||
async def _get_cached_blocks() -> str:
|
||||
"""
|
||||
Async cached function with thundering herd protection.
|
||||
On cache miss: runs heavy work in thread pool
|
||||
On cache hit: returns cached string immediately (no thread pool needed)
|
||||
"""
|
||||
# Only run in thread pool on cache miss - cache hits return immediately
|
||||
return await run_in_threadpool(_compute_blocks_sync)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/blocks",
|
||||
summary="List available blocks",
|
||||
tags=["blocks"],
|
||||
dependencies=[Security(requires_user)],
|
||||
responses={
|
||||
200: {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"items": {"additionalProperties": True, "type": "object"},
|
||||
"type": "array",
|
||||
"title": "Response Getv1List Available Blocks",
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
blocks = [block() for block in get_blocks().values()]
|
||||
costs = get_block_costs()
|
||||
return [
|
||||
{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks if not b.disabled
|
||||
]
|
||||
async def get_graph_blocks() -> Response:
|
||||
# Cache hit: returns immediately, Cache miss: runs in thread pool
|
||||
content = await _get_cached_blocks()
|
||||
return Response(
|
||||
content=content,
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
@@ -282,15 +337,29 @@ def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
tags=["blocks"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlockOutput:
|
||||
async def execute_graph_block(
|
||||
block_id: str, data: BlockInput, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> CompletedBlockOutput:
|
||||
obj = get_block(block_id)
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
||||
|
||||
# Get user context for block execution
|
||||
user = await get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found.")
|
||||
|
||||
user_context = UserContext(timezone=user.timezone)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
output = defaultdict(list)
|
||||
async for name, data in obj.execute(data):
|
||||
async for name, data in obj.execute(
|
||||
data,
|
||||
user_context=user_context,
|
||||
user_id=user_id,
|
||||
# Note: graph_exec_id and graph_id are not available for direct block execution
|
||||
):
|
||||
output[name].append(data)
|
||||
|
||||
# Record successful block execution with duration
|
||||
@@ -599,7 +668,13 @@ class DeleteGraphResponse(TypedDict):
|
||||
async def list_graphs(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> Sequence[graph_db.GraphMeta]:
|
||||
return await graph_db.list_graphs(filter_by="active", user_id=user_id)
|
||||
paginated_result = await graph_db.list_graphs_paginated(
|
||||
user_id=user_id,
|
||||
page=1,
|
||||
page_size=250,
|
||||
filter_by="active",
|
||||
)
|
||||
return paginated_result.graphs
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
@@ -888,7 +963,12 @@ async def _stop_graph_run(
|
||||
async def list_graphs_executions(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[execution_db.GraphExecutionMeta]:
|
||||
return await execution_db.get_graph_executions(user_id=user_id)
|
||||
paginated_result = await execution_db.get_graph_executions_paginated(
|
||||
user_id=user_id,
|
||||
page=1,
|
||||
page_size=250,
|
||||
)
|
||||
return paginated_result.executions
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
|
||||
@@ -110,8 +110,8 @@ def test_get_graph_blocks(
|
||||
|
||||
# Mock block costs
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.get_block_costs",
|
||||
return_value={"test-block": [{"cost": 10, "type": "credit"}]},
|
||||
"backend.data.credit.get_block_cost",
|
||||
return_value=[{"cost": 10, "type": "credit"}],
|
||||
)
|
||||
|
||||
response = client.get("/blocks")
|
||||
@@ -147,6 +147,15 @@ def test_execute_graph_block(
|
||||
return_value=mock_block,
|
||||
)
|
||||
|
||||
# Mock user for user_context
|
||||
mock_user = Mock()
|
||||
mock_user.timezone = "UTC"
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.get_user_by_id",
|
||||
return_value=mock_user,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"input_name": "test_input",
|
||||
"input_value": "test_value",
|
||||
@@ -270,8 +279,8 @@ def test_get_graphs(
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"backend.server.routers.v1.graph_db.list_graphs",
|
||||
return_value=[mock_graph],
|
||||
"backend.data.graph.list_graphs_paginated",
|
||||
return_value=Mock(graphs=[mock_graph]),
|
||||
)
|
||||
|
||||
response = client.get("/graphs")
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import functools
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import prisma
|
||||
from autogpt_libs.utils.cache import cached
|
||||
|
||||
import backend.data.block
|
||||
from backend.blocks import load_all_blocks
|
||||
@@ -296,7 +296,7 @@ def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@functools.cache
|
||||
@cached()
|
||||
def _get_all_providers() -> dict[ProviderName, Provider]:
|
||||
providers: dict[ProviderName, Provider] = {}
|
||||
|
||||
|
||||
@@ -101,7 +101,9 @@ async def list_library_agents(
|
||||
try:
|
||||
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
include=library_agent_include(user_id),
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
@@ -185,7 +187,9 @@ async def list_favorite_library_agents(
|
||||
try:
|
||||
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
include=library_agent_include(user_id),
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
@@ -417,7 +421,9 @@ async def create_library_agent(
|
||||
}
|
||||
},
|
||||
),
|
||||
include=library_agent_include(user_id),
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
)
|
||||
for graph_entry in graph_entries
|
||||
)
|
||||
@@ -642,7 +648,9 @@ async def add_store_agent_to_library(
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
},
|
||||
include=library_agent_include(user_id),
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
)
|
||||
logger.debug(
|
||||
f"Added graph #{graph.id} v{graph.version}"
|
||||
|
||||
@@ -177,7 +177,9 @@ async def test_add_agent_to_library(mocker):
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
},
|
||||
include=library_agent_include("test-user"),
|
||||
include=library_agent_include(
|
||||
"test-user", include_nodes=False, include_executions=False
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import urllib.parse
|
||||
import autogpt_libs.auth
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
from autogpt_libs.utils.cache import cached
|
||||
|
||||
import backend.data.graph
|
||||
import backend.server.v2.store.db
|
||||
@@ -20,6 +21,117 @@ logger = logging.getLogger(__name__)
|
||||
router = fastapi.APIRouter()
|
||||
|
||||
|
||||
##############################################
|
||||
############### Caches #######################
|
||||
##############################################
|
||||
|
||||
|
||||
# Cache user profiles for 1 hour per user
|
||||
@cached(maxsize=1000, ttl_seconds=3600)
|
||||
async def _get_cached_user_profile(user_id: str):
|
||||
"""Cached helper to get user profile."""
|
||||
return await backend.server.v2.store.db.get_user_profile(user_id)
|
||||
|
||||
|
||||
# Cache store agents list for 15 minutes
|
||||
# Different cache entries for different query combinations
|
||||
@cached(maxsize=5000, ttl_seconds=900)
|
||||
async def _get_cached_store_agents(
|
||||
featured: bool,
|
||||
creator: str | None,
|
||||
sorted_by: str | None,
|
||||
search_query: str | None,
|
||||
category: str | None,
|
||||
page: int,
|
||||
page_size: int,
|
||||
):
|
||||
"""Cached helper to get store agents."""
|
||||
return await backend.server.v2.store.db.get_store_agents(
|
||||
featured=featured,
|
||||
creators=[creator] if creator else None,
|
||||
sorted_by=sorted_by,
|
||||
search_query=search_query,
|
||||
category=category,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
# Cache individual agent details for 15 minutes
|
||||
@cached(maxsize=200, ttl_seconds=900)
|
||||
async def _get_cached_agent_details(username: str, agent_name: str):
|
||||
"""Cached helper to get agent details."""
|
||||
return await backend.server.v2.store.db.get_store_agent_details(
|
||||
username=username, agent_name=agent_name
|
||||
)
|
||||
|
||||
|
||||
# Cache agent graphs for 1 hour
|
||||
@cached(maxsize=200, ttl_seconds=3600)
|
||||
async def _get_cached_agent_graph(store_listing_version_id: str):
|
||||
"""Cached helper to get agent graph."""
|
||||
return await backend.server.v2.store.db.get_available_graph(
|
||||
store_listing_version_id
|
||||
)
|
||||
|
||||
|
||||
# Cache agent by version for 1 hour
|
||||
@cached(maxsize=200, ttl_seconds=3600)
|
||||
async def _get_cached_store_agent_by_version(store_listing_version_id: str):
|
||||
"""Cached helper to get store agent by version ID."""
|
||||
return await backend.server.v2.store.db.get_store_agent_by_version_id(
|
||||
store_listing_version_id
|
||||
)
|
||||
|
||||
|
||||
# Cache creators list for 1 hour
|
||||
@cached(maxsize=200, ttl_seconds=3600)
|
||||
async def _get_cached_store_creators(
|
||||
featured: bool,
|
||||
search_query: str | None,
|
||||
sorted_by: str | None,
|
||||
page: int,
|
||||
page_size: int,
|
||||
):
|
||||
"""Cached helper to get store creators."""
|
||||
return await backend.server.v2.store.db.get_store_creators(
|
||||
featured=featured,
|
||||
search_query=search_query,
|
||||
sorted_by=sorted_by,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
# Cache individual creator details for 1 hour
|
||||
@cached(maxsize=100, ttl_seconds=3600)
|
||||
async def _get_cached_creator_details(username: str):
|
||||
"""Cached helper to get creator details."""
|
||||
return await backend.server.v2.store.db.get_store_creator_details(
|
||||
username=username.lower()
|
||||
)
|
||||
|
||||
|
||||
# Cache user's own agents for 5 mins (shorter TTL as this changes more frequently)
|
||||
@cached(maxsize=500, ttl_seconds=300)
|
||||
async def _get_cached_my_agents(user_id: str, page: int, page_size: int):
|
||||
"""Cached helper to get user's agents."""
|
||||
return await backend.server.v2.store.db.get_my_agents(
|
||||
user_id, page=page, page_size=page_size
|
||||
)
|
||||
|
||||
|
||||
# Cache user's submissions for 1 hour (shorter TTL as this changes frequently)
|
||||
@cached(maxsize=500, ttl_seconds=3600)
|
||||
async def _get_cached_submissions(user_id: str, page: int, page_size: int):
|
||||
"""Cached helper to get user's submissions."""
|
||||
return await backend.server.v2.store.db.get_store_submissions(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
############### Profile Endpoints ############
|
||||
##############################################
|
||||
@@ -37,9 +149,10 @@ async def get_profile(
|
||||
):
|
||||
"""
|
||||
Get the profile details for the authenticated user.
|
||||
Cached for 1 hour per user.
|
||||
"""
|
||||
try:
|
||||
profile = await backend.server.v2.store.db.get_user_profile(user_id)
|
||||
profile = await _get_cached_user_profile(user_id)
|
||||
if profile is None:
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=404,
|
||||
@@ -85,6 +198,8 @@ async def update_or_create_profile(
|
||||
updated_profile = await backend.server.v2.store.db.update_profile(
|
||||
user_id=user_id, profile=profile
|
||||
)
|
||||
# Clear the cache for this user after profile update
|
||||
_get_cached_user_profile.cache_delete(user_id)
|
||||
return updated_profile
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update profile for user %s: %s", user_id, e)
|
||||
@@ -119,6 +234,7 @@ async def get_agents(
|
||||
):
|
||||
"""
|
||||
Get a paginated list of agents from the store with optional filtering and sorting.
|
||||
Results are cached for 15 minutes.
|
||||
|
||||
Args:
|
||||
featured (bool, optional): Filter to only show featured agents. Defaults to False.
|
||||
@@ -154,9 +270,9 @@ async def get_agents(
|
||||
)
|
||||
|
||||
try:
|
||||
agents = await backend.server.v2.store.db.get_store_agents(
|
||||
agents = await _get_cached_store_agents(
|
||||
featured=featured,
|
||||
creators=[creator] if creator else None,
|
||||
creator=creator,
|
||||
sorted_by=sorted_by,
|
||||
search_query=search_query,
|
||||
category=category,
|
||||
@@ -183,7 +299,8 @@ async def get_agents(
|
||||
)
|
||||
async def get_agent(username: str, agent_name: str):
|
||||
"""
|
||||
This is only used on the AgentDetails Page
|
||||
This is only used on the AgentDetails Page.
|
||||
Results are cached for 15 minutes.
|
||||
|
||||
It returns the store listing agents details.
|
||||
"""
|
||||
@@ -191,7 +308,7 @@ async def get_agent(username: str, agent_name: str):
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
# URL decode the agent name since it comes from the URL path
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
agent = await backend.server.v2.store.db.get_store_agent_details(
|
||||
agent = await _get_cached_agent_details(
|
||||
username=username, agent_name=agent_name
|
||||
)
|
||||
return agent
|
||||
@@ -214,11 +331,10 @@ async def get_agent(username: str, agent_name: str):
|
||||
async def get_graph_meta_by_store_listing_version_id(store_listing_version_id: str):
|
||||
"""
|
||||
Get Agent Graph from Store Listing Version ID.
|
||||
Results are cached for 1 hour.
|
||||
"""
|
||||
try:
|
||||
graph = await backend.server.v2.store.db.get_available_graph(
|
||||
store_listing_version_id
|
||||
)
|
||||
graph = await _get_cached_agent_graph(store_listing_version_id)
|
||||
return graph
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst getting agent graph")
|
||||
@@ -238,11 +354,10 @@ async def get_graph_meta_by_store_listing_version_id(store_listing_version_id: s
|
||||
async def get_store_agent(store_listing_version_id: str):
|
||||
"""
|
||||
Get Store Agent Details from Store Listing Version ID.
|
||||
Results are cached for 1 hour.
|
||||
"""
|
||||
try:
|
||||
agent = await backend.server.v2.store.db.get_store_agent_by_version_id(
|
||||
store_listing_version_id
|
||||
)
|
||||
agent = await _get_cached_store_agent_by_version(store_listing_version_id)
|
||||
return agent
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst getting store agent")
|
||||
@@ -279,7 +394,7 @@ async def create_review(
|
||||
"""
|
||||
try:
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
agent_name = urllib.parse.unquote(agent_name)
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
# Create the review
|
||||
created_review = await backend.server.v2.store.db.create_store_review(
|
||||
user_id=user_id,
|
||||
@@ -320,6 +435,8 @@ async def get_creators(
|
||||
- Home Page Featured Creators
|
||||
- Search Results Page
|
||||
|
||||
Results are cached for 1 hour.
|
||||
|
||||
---
|
||||
|
||||
To support this functionality we need:
|
||||
@@ -338,7 +455,7 @@ async def get_creators(
|
||||
)
|
||||
|
||||
try:
|
||||
creators = await backend.server.v2.store.db.get_store_creators(
|
||||
creators = await _get_cached_store_creators(
|
||||
featured=featured,
|
||||
search_query=search_query,
|
||||
sorted_by=sorted_by,
|
||||
@@ -364,14 +481,13 @@ async def get_creator(
|
||||
username: str,
|
||||
):
|
||||
"""
|
||||
Get the details of a creator
|
||||
Get the details of a creator.
|
||||
Results are cached for 1 hour.
|
||||
- Creator Details Page
|
||||
"""
|
||||
try:
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
creator = await backend.server.v2.store.db.get_store_creator_details(
|
||||
username=username.lower()
|
||||
)
|
||||
creator = await _get_cached_creator_details(username=username)
|
||||
return creator
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst getting creator details")
|
||||
@@ -386,6 +502,8 @@ async def get_creator(
|
||||
############################################
|
||||
############# Store Submissions ###############
|
||||
############################################
|
||||
|
||||
|
||||
@router.get(
|
||||
"/myagents",
|
||||
summary="Get my agents",
|
||||
@@ -398,10 +516,12 @@ async def get_my_agents(
|
||||
page: typing.Annotated[int, fastapi.Query(ge=1)] = 1,
|
||||
page_size: typing.Annotated[int, fastapi.Query(ge=1)] = 20,
|
||||
):
|
||||
"""
|
||||
Get user's own agents.
|
||||
Results are cached for 5 minutes per user.
|
||||
"""
|
||||
try:
|
||||
agents = await backend.server.v2.store.db.get_my_agents(
|
||||
user_id, page=page, page_size=page_size
|
||||
)
|
||||
agents = await _get_cached_my_agents(user_id, page=page, page_size=page_size)
|
||||
return agents
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst getting my agents")
|
||||
@@ -437,6 +557,14 @@ async def delete_submission(
|
||||
user_id=user_id,
|
||||
submission_id=submission_id,
|
||||
)
|
||||
|
||||
# Clear submissions cache for this specific user after deletion
|
||||
if result:
|
||||
# Clear user's own agents cache - we don't know all page/size combinations
|
||||
for page in range(1, 20):
|
||||
# Clear user's submissions cache for common defaults
|
||||
_get_cached_submissions.cache_delete(user_id, page=page, page_size=20)
|
||||
|
||||
return result
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst deleting store submission")
|
||||
@@ -460,6 +588,7 @@ async def get_submissions(
|
||||
):
|
||||
"""
|
||||
Get a paginated list of store submissions for the authenticated user.
|
||||
Results are cached for 1 hour per user.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the authenticated user
|
||||
@@ -482,10 +611,8 @@ async def get_submissions(
|
||||
status_code=422, detail="Page size must be greater than 0"
|
||||
)
|
||||
try:
|
||||
listings = await backend.server.v2.store.db.get_store_submissions(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
listings = await _get_cached_submissions(
|
||||
user_id, page=page, page_size=page_size
|
||||
)
|
||||
return listings
|
||||
except Exception:
|
||||
@@ -523,7 +650,7 @@ async def create_submission(
|
||||
HTTPException: If there is an error creating the submission
|
||||
"""
|
||||
try:
|
||||
return await backend.server.v2.store.db.create_store_submission(
|
||||
result = await backend.server.v2.store.db.create_store_submission(
|
||||
user_id=user_id,
|
||||
agent_id=submission_request.agent_id,
|
||||
agent_version=submission_request.agent_version,
|
||||
@@ -538,6 +665,13 @@ async def create_submission(
|
||||
changes_summary=submission_request.changes_summary or "Initial Submission",
|
||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||
)
|
||||
|
||||
# Clear user's own agents cache - we don't know all page/size combinations
|
||||
for page in range(1, 20):
|
||||
# Clear user's submissions cache for common defaults
|
||||
_get_cached_submissions.cache_delete(user_id, page=page, page_size=20)
|
||||
|
||||
return result
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst creating store submission")
|
||||
return fastapi.responses.JSONResponse(
|
||||
@@ -572,7 +706,7 @@ async def edit_submission(
|
||||
Raises:
|
||||
HTTPException: If there is an error editing the submission
|
||||
"""
|
||||
return await backend.server.v2.store.db.edit_store_submission(
|
||||
result = await backend.server.v2.store.db.edit_store_submission(
|
||||
user_id=user_id,
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
name=submission_request.name,
|
||||
@@ -586,6 +720,13 @@ async def edit_submission(
|
||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||
)
|
||||
|
||||
# Clear user's own agents cache - we don't know all page/size combinations
|
||||
for page in range(1, 20):
|
||||
# Clear user's submissions cache for common defaults
|
||||
_get_cached_submissions.cache_delete(user_id, page=page, page_size=20)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post(
|
||||
"/submissions/media",
|
||||
@@ -737,3 +878,63 @@ async def download_agent_file(
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
|
||||
|
||||
##############################################
|
||||
############### Cache Management #############
|
||||
##############################################
|
||||
|
||||
|
||||
@router.get(
|
||||
"/metrics/cache",
|
||||
summary="Get cache metrics in Prometheus format",
|
||||
tags=["store", "metrics"],
|
||||
response_class=fastapi.responses.PlainTextResponse,
|
||||
)
|
||||
async def get_cache_metrics():
|
||||
"""
|
||||
Get cache metrics in Prometheus text format.
|
||||
|
||||
Returns Prometheus-compatible metrics for monitoring cache performance.
|
||||
Metrics include size, maxsize, TTL, and hit rate for each cache.
|
||||
|
||||
Returns:
|
||||
str: Prometheus-formatted metrics text
|
||||
"""
|
||||
metrics = []
|
||||
|
||||
# Helper to add metrics for a cache
|
||||
def add_cache_metrics(cache_name: str, cache_func):
|
||||
info = cache_func.cache_info()
|
||||
# Cache size metric (dynamic - changes as items are cached/expired)
|
||||
metrics.append(f'store_cache_entries{{cache="{cache_name}"}} {info["size"]}')
|
||||
# Cache utilization percentage (dynamic - useful for monitoring)
|
||||
utilization = (
|
||||
(info["size"] / info["maxsize"] * 100) if info["maxsize"] > 0 else 0
|
||||
)
|
||||
metrics.append(
|
||||
f'store_cache_utilization_percent{{cache="{cache_name}"}} {utilization:.2f}'
|
||||
)
|
||||
|
||||
# Add metrics for each cache
|
||||
add_cache_metrics("user_profile", _get_cached_user_profile)
|
||||
add_cache_metrics("store_agents", _get_cached_store_agents)
|
||||
add_cache_metrics("agent_details", _get_cached_agent_details)
|
||||
add_cache_metrics("agent_graph", _get_cached_agent_graph)
|
||||
add_cache_metrics("agent_by_version", _get_cached_store_agent_by_version)
|
||||
add_cache_metrics("store_creators", _get_cached_store_creators)
|
||||
add_cache_metrics("creator_details", _get_cached_creator_details)
|
||||
add_cache_metrics("my_agents", _get_cached_my_agents)
|
||||
add_cache_metrics("submissions", _get_cached_submissions)
|
||||
|
||||
# Add metadata/help text at the beginning
|
||||
prometheus_output = [
|
||||
"# HELP store_cache_entries Number of entries currently in cache",
|
||||
"# TYPE store_cache_entries gauge",
|
||||
"# HELP store_cache_utilization_percent Cache utilization as percentage (0-100)",
|
||||
"# TYPE store_cache_utilization_percent gauge",
|
||||
"", # Empty line before metrics
|
||||
]
|
||||
prometheus_output.extend(metrics)
|
||||
|
||||
return "\n".join(prometheus_output)
|
||||
|
||||
@@ -0,0 +1,351 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test suite for verifying cache_delete functionality in store routes.
|
||||
Tests that specific cache entries can be deleted while preserving others.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.store import routes
|
||||
from backend.server.v2.store.model import (
|
||||
ProfileDetails,
|
||||
StoreAgent,
|
||||
StoreAgentDetails,
|
||||
StoreAgentsResponse,
|
||||
)
|
||||
from backend.util.models import Pagination
|
||||
|
||||
|
||||
class TestCacheDeletion:
|
||||
"""Test cache deletion functionality for store routes."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_agents_cache_delete(self):
|
||||
"""Test that specific agent list cache entries can be deleted."""
|
||||
# Mock the database function
|
||||
mock_response = StoreAgentsResponse(
|
||||
agents=[
|
||||
StoreAgent(
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_image="https://example.com/image.jpg",
|
||||
creator="testuser",
|
||||
creator_avatar="https://example.com/avatar.jpg",
|
||||
sub_heading="Test subheading",
|
||||
description="Test description",
|
||||
runs=100,
|
||||
rating=4.5,
|
||||
)
|
||||
],
|
||||
pagination=Pagination(
|
||||
total_items=1,
|
||||
total_pages=1,
|
||||
current_page=1,
|
||||
page_size=20,
|
||||
),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.server.v2.store.db.get_store_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_db:
|
||||
# Clear cache first
|
||||
routes._get_cached_store_agents.cache_clear()
|
||||
|
||||
# First call - should hit database
|
||||
result1 = await routes._get_cached_store_agents(
|
||||
featured=False,
|
||||
creator=None,
|
||||
sorted_by=None,
|
||||
search_query="test",
|
||||
category=None,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
assert mock_db.call_count == 1
|
||||
assert result1.agents[0].agent_name == "Test Agent"
|
||||
|
||||
# Second call with same params - should use cache
|
||||
await routes._get_cached_store_agents(
|
||||
featured=False,
|
||||
creator=None,
|
||||
sorted_by=None,
|
||||
search_query="test",
|
||||
category=None,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
assert mock_db.call_count == 1 # No additional DB call
|
||||
|
||||
# Third call with different params - should hit database
|
||||
await routes._get_cached_store_agents(
|
||||
featured=True, # Different param
|
||||
creator=None,
|
||||
sorted_by=None,
|
||||
search_query="test",
|
||||
category=None,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
assert mock_db.call_count == 2 # New DB call
|
||||
|
||||
# Delete specific cache entry
|
||||
deleted = routes._get_cached_store_agents.cache_delete(
|
||||
featured=False,
|
||||
creator=None,
|
||||
sorted_by=None,
|
||||
search_query="test",
|
||||
category=None,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
assert deleted is True # Entry was deleted
|
||||
|
||||
# Try to delete non-existent entry
|
||||
deleted = routes._get_cached_store_agents.cache_delete(
|
||||
featured=False,
|
||||
creator="nonexistent",
|
||||
sorted_by=None,
|
||||
search_query="test",
|
||||
category=None,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
assert deleted is False # Entry didn't exist
|
||||
|
||||
# Call with deleted params - should hit database again
|
||||
await routes._get_cached_store_agents(
|
||||
featured=False,
|
||||
creator=None,
|
||||
sorted_by=None,
|
||||
search_query="test",
|
||||
category=None,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
assert mock_db.call_count == 3 # New DB call after deletion
|
||||
|
||||
# Call with featured=True - should still be cached
|
||||
await routes._get_cached_store_agents(
|
||||
featured=True,
|
||||
creator=None,
|
||||
sorted_by=None,
|
||||
search_query="test",
|
||||
category=None,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
assert mock_db.call_count == 3 # No additional DB call
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_details_cache_delete(self):
|
||||
"""Test that specific agent details cache entries can be deleted."""
|
||||
mock_response = StoreAgentDetails(
|
||||
store_listing_version_id="version1",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent",
|
||||
agent_video="https://example.com/video.mp4",
|
||||
agent_image=["https://example.com/image.jpg"],
|
||||
creator="testuser",
|
||||
creator_avatar="https://example.com/avatar.jpg",
|
||||
sub_heading="Test subheading",
|
||||
description="Test description",
|
||||
categories=["productivity"],
|
||||
runs=100,
|
||||
rating=4.5,
|
||||
versions=[],
|
||||
last_updated=datetime.datetime(2024, 1, 1),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.server.v2.store.db.get_store_agent_details",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_db:
|
||||
# Clear cache first
|
||||
routes._get_cached_agent_details.cache_clear()
|
||||
|
||||
# First call - should hit database
|
||||
await routes._get_cached_agent_details(
|
||||
username="testuser", agent_name="testagent"
|
||||
)
|
||||
assert mock_db.call_count == 1
|
||||
|
||||
# Second call - should use cache
|
||||
await routes._get_cached_agent_details(
|
||||
username="testuser", agent_name="testagent"
|
||||
)
|
||||
assert mock_db.call_count == 1 # No additional DB call
|
||||
|
||||
# Delete specific entry
|
||||
deleted = routes._get_cached_agent_details.cache_delete(
|
||||
username="testuser", agent_name="testagent"
|
||||
)
|
||||
assert deleted is True
|
||||
|
||||
# Call again - should hit database
|
||||
await routes._get_cached_agent_details(
|
||||
username="testuser", agent_name="testagent"
|
||||
)
|
||||
assert mock_db.call_count == 2 # New DB call after deletion
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_profile_cache_delete(self):
|
||||
"""Test that user profile cache entries can be deleted."""
|
||||
mock_response = ProfileDetails(
|
||||
name="Test User",
|
||||
username="testuser",
|
||||
description="Test profile",
|
||||
links=["https://example.com"],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.server.v2.store.db.get_user_profile",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_db:
|
||||
# Clear cache first
|
||||
routes._get_cached_user_profile.cache_clear()
|
||||
|
||||
# First call - should hit database
|
||||
await routes._get_cached_user_profile("user123")
|
||||
assert mock_db.call_count == 1
|
||||
|
||||
# Second call - should use cache
|
||||
await routes._get_cached_user_profile("user123")
|
||||
assert mock_db.call_count == 1
|
||||
|
||||
# Different user - should hit database
|
||||
await routes._get_cached_user_profile("user456")
|
||||
assert mock_db.call_count == 2
|
||||
|
||||
# Delete specific user's cache
|
||||
deleted = routes._get_cached_user_profile.cache_delete("user123")
|
||||
assert deleted is True
|
||||
|
||||
# user123 should hit database again
|
||||
await routes._get_cached_user_profile("user123")
|
||||
assert mock_db.call_count == 3
|
||||
|
||||
# user456 should still be cached
|
||||
await routes._get_cached_user_profile("user456")
|
||||
assert mock_db.call_count == 3 # No additional DB call
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_info_after_deletions(self):
|
||||
"""Test that cache_info correctly reflects deletions."""
|
||||
# Clear all caches first
|
||||
routes._get_cached_store_agents.cache_clear()
|
||||
|
||||
mock_response = StoreAgentsResponse(
|
||||
agents=[],
|
||||
pagination=Pagination(
|
||||
total_items=0,
|
||||
total_pages=1,
|
||||
current_page=1,
|
||||
page_size=20,
|
||||
),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.server.v2.store.db.get_store_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
):
|
||||
# Add multiple entries
|
||||
for i in range(5):
|
||||
await routes._get_cached_store_agents(
|
||||
featured=False,
|
||||
creator=f"creator{i}",
|
||||
sorted_by=None,
|
||||
search_query=None,
|
||||
category=None,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
# Check cache size
|
||||
info = routes._get_cached_store_agents.cache_info()
|
||||
assert info["size"] == 5
|
||||
|
||||
# Delete some entries
|
||||
for i in range(2):
|
||||
deleted = routes._get_cached_store_agents.cache_delete(
|
||||
featured=False,
|
||||
creator=f"creator{i}",
|
||||
sorted_by=None,
|
||||
search_query=None,
|
||||
category=None,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
assert deleted is True
|
||||
|
||||
# Check cache size after deletion
|
||||
info = routes._get_cached_store_agents.cache_info()
|
||||
assert info["size"] == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_delete_with_complex_params(self):
|
||||
"""Test cache deletion with various parameter combinations."""
|
||||
mock_response = StoreAgentsResponse(
|
||||
agents=[],
|
||||
pagination=Pagination(
|
||||
total_items=0,
|
||||
total_pages=1,
|
||||
current_page=1,
|
||||
page_size=20,
|
||||
),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.server.v2.store.db.get_store_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_db:
|
||||
routes._get_cached_store_agents.cache_clear()
|
||||
|
||||
# Test with all parameters
|
||||
await routes._get_cached_store_agents(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by="rating",
|
||||
search_query="AI assistant",
|
||||
category="productivity",
|
||||
page=2,
|
||||
page_size=50,
|
||||
)
|
||||
assert mock_db.call_count == 1
|
||||
|
||||
# Delete with exact same parameters
|
||||
deleted = routes._get_cached_store_agents.cache_delete(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by="rating",
|
||||
search_query="AI assistant",
|
||||
category="productivity",
|
||||
page=2,
|
||||
page_size=50,
|
||||
)
|
||||
assert deleted is True
|
||||
|
||||
# Try to delete with slightly different parameters
|
||||
deleted = routes._get_cached_store_agents.cache_delete(
|
||||
featured=True,
|
||||
creator="testuser",
|
||||
sorted_by="rating",
|
||||
search_query="AI assistant",
|
||||
category="productivity",
|
||||
page=2,
|
||||
page_size=51, # Different page_size
|
||||
)
|
||||
assert deleted is False # Different parameters, not in cache
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the tests
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -2,10 +2,9 @@
|
||||
Centralized service client helpers with thread caching.
|
||||
"""
|
||||
|
||||
from functools import cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from autogpt_libs.utils.cache import async_cache, thread_cached
|
||||
from autogpt_libs.utils.cache import cached, thread_cached
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
@@ -119,7 +118,7 @@ def get_integration_credentials_store() -> "IntegrationCredentialsStore":
|
||||
# ============ Supabase Clients ============ #
|
||||
|
||||
|
||||
@cache
|
||||
@cached()
|
||||
def get_supabase() -> "Client":
|
||||
"""Get a process-cached synchronous Supabase client instance."""
|
||||
from supabase import create_client
|
||||
@@ -129,7 +128,7 @@ def get_supabase() -> "Client":
|
||||
)
|
||||
|
||||
|
||||
@async_cache
|
||||
@cached()
|
||||
async def get_async_supabase() -> "AClient":
|
||||
"""Get a process-cached asynchronous Supabase client instance."""
|
||||
from supabase import create_async_client
|
||||
|
||||
@@ -9,6 +9,7 @@ import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Tuple
|
||||
|
||||
import aiohttp
|
||||
from gcloud.aio import storage as async_gcs_storage
|
||||
from google.cloud import storage as gcs_storage
|
||||
|
||||
@@ -38,20 +39,59 @@ class CloudStorageHandler:
|
||||
self.config = config
|
||||
self._async_gcs_client = None
|
||||
self._sync_gcs_client = None # Only for signed URLs
|
||||
self._session = None
|
||||
|
||||
async def _get_async_gcs_client(self):
|
||||
"""Get or create async GCS client, ensuring it's created in proper async context."""
|
||||
# Check if we already have a client
|
||||
if self._async_gcs_client is not None:
|
||||
return self._async_gcs_client
|
||||
|
||||
current_task = asyncio.current_task()
|
||||
if not current_task:
|
||||
# If we're not in a task, create a temporary client
|
||||
logger.warning(
|
||||
"[CloudStorage] Creating GCS client outside of task context - using temporary client"
|
||||
)
|
||||
timeout = aiohttp.ClientTimeout(total=300)
|
||||
session = aiohttp.ClientSession(
|
||||
timeout=timeout,
|
||||
connector=aiohttp.TCPConnector(limit=100, force_close=False),
|
||||
)
|
||||
return async_gcs_storage.Storage(session=session)
|
||||
|
||||
# Create a reusable session with proper configuration
|
||||
# Key fix: Don't set timeout on session, let gcloud-aio handle it
|
||||
self._session = aiohttp.ClientSession(
|
||||
connector=aiohttp.TCPConnector(
|
||||
limit=100, # Connection pool limit
|
||||
force_close=False, # Reuse connections
|
||||
enable_cleanup_closed=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Create the GCS client with our session
|
||||
# The key is NOT setting timeout on the session but letting the library handle it
|
||||
self._async_gcs_client = async_gcs_storage.Storage(session=self._session)
|
||||
|
||||
def _get_async_gcs_client(self):
|
||||
"""Lazy initialization of async GCS client."""
|
||||
if self._async_gcs_client is None:
|
||||
# Use Application Default Credentials (ADC)
|
||||
self._async_gcs_client = async_gcs_storage.Storage()
|
||||
return self._async_gcs_client
|
||||
|
||||
async def close(self):
|
||||
"""Close all client connections properly."""
|
||||
if self._async_gcs_client is not None:
|
||||
await self._async_gcs_client.close()
|
||||
try:
|
||||
await self._async_gcs_client.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[CloudStorage] Error closing GCS client: {e}")
|
||||
self._async_gcs_client = None
|
||||
|
||||
if self._session is not None:
|
||||
try:
|
||||
await self._session.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[CloudStorage] Error closing session: {e}")
|
||||
self._session = None
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry."""
|
||||
return self
|
||||
@@ -141,7 +181,7 @@ class CloudStorageHandler:
|
||||
if user_id and graph_exec_id:
|
||||
raise ValueError("Provide either user_id OR graph_exec_id, not both")
|
||||
|
||||
async_client = self._get_async_gcs_client()
|
||||
async_client = await self._get_async_gcs_client()
|
||||
|
||||
# Generate unique path with appropriate scope
|
||||
unique_id = str(uuid.uuid4())
|
||||
@@ -203,6 +243,15 @@ class CloudStorageHandler:
|
||||
self, path: str, user_id: str | None = None, graph_exec_id: str | None = None
|
||||
) -> bytes:
|
||||
"""Retrieve file from Google Cloud Storage with authorization."""
|
||||
# Log context for debugging
|
||||
current_task = asyncio.current_task()
|
||||
logger.info(
|
||||
f"[CloudStorage]"
|
||||
f"_retrieve_file_gcs called - "
|
||||
f"current_task: {current_task}, "
|
||||
f"in_task: {current_task is not None}"
|
||||
)
|
||||
|
||||
# Parse bucket and blob name from path
|
||||
parts = path.split("/", 1)
|
||||
if len(parts) != 2:
|
||||
@@ -213,13 +262,65 @@ class CloudStorageHandler:
|
||||
# Authorization check
|
||||
self._validate_file_access(blob_name, user_id, graph_exec_id)
|
||||
|
||||
async_client = self._get_async_gcs_client()
|
||||
# Use a fresh client for each download to avoid session issues
|
||||
# This is less efficient but more reliable with the executor's event loop
|
||||
logger.info("[CloudStorage] Creating fresh GCS client for download")
|
||||
|
||||
# Create a new session specifically for this download
|
||||
session = aiohttp.ClientSession(
|
||||
connector=aiohttp.TCPConnector(limit=10, force_close=True)
|
||||
)
|
||||
|
||||
async_client = None
|
||||
try:
|
||||
# Download content using pure async client
|
||||
# Create a new GCS client with the fresh session
|
||||
async_client = async_gcs_storage.Storage(session=session)
|
||||
|
||||
logger.info(
|
||||
f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}"
|
||||
)
|
||||
|
||||
# Download content using the fresh client
|
||||
content = await async_client.download(bucket_name, blob_name)
|
||||
logger.info(
|
||||
f"[CloudStorage] GCS download successful - size: {len(content)} bytes"
|
||||
)
|
||||
|
||||
# Clean up
|
||||
await async_client.close()
|
||||
await session.close()
|
||||
|
||||
return content
|
||||
|
||||
except Exception as e:
|
||||
# Always try to clean up
|
||||
if async_client is not None:
|
||||
try:
|
||||
await async_client.close()
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(
|
||||
f"[CloudStorage] Error closing GCS client: {cleanup_error}"
|
||||
)
|
||||
try:
|
||||
await session.close()
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"[CloudStorage] Error closing session: {cleanup_error}")
|
||||
|
||||
# Log the specific error for debugging
|
||||
logger.error(
|
||||
f"[CloudStorage] GCS download failed - error: {str(e)}, "
|
||||
f"error_type: {type(e).__name__}, "
|
||||
f"bucket: {bucket_name}, blob: redacted for privacy"
|
||||
)
|
||||
|
||||
# Special handling for timeout error
|
||||
if "Timeout context manager" in str(e):
|
||||
logger.critical(
|
||||
f"[CloudStorage] TIMEOUT ERROR in GCS download! "
|
||||
f"current_task: {current_task}, "
|
||||
f"bucket: {bucket_name}, blob: redacted for privacy"
|
||||
)
|
||||
|
||||
# Convert gcloud-aio exceptions to standard ones
|
||||
if "404" in str(e) or "Not Found" in str(e):
|
||||
raise FileNotFoundError(f"File not found: gcs://{path}")
|
||||
@@ -303,7 +404,7 @@ class CloudStorageHandler:
|
||||
|
||||
# Legacy uploads directory (uploads/*) - allow for backwards compatibility with warning
|
||||
# Note: We already validated it starts with "uploads/" above, so this is guaranteed to match
|
||||
logger.warning(f"Accessing legacy upload path: {blob_name}")
|
||||
logger.warning(f"[CloudStorage] Accessing legacy upload path: {blob_name}")
|
||||
return
|
||||
|
||||
async def generate_signed_url(
|
||||
@@ -391,7 +492,7 @@ class CloudStorageHandler:
|
||||
if not self.config.gcs_bucket_name:
|
||||
raise ValueError("GCS_BUCKET_NAME not configured")
|
||||
|
||||
async_client = self._get_async_gcs_client()
|
||||
async_client = await self._get_async_gcs_client()
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
try:
|
||||
@@ -431,7 +532,7 @@ class CloudStorageHandler:
|
||||
except Exception as e:
|
||||
# Log specific errors for debugging
|
||||
logger.warning(
|
||||
f"Failed to process file {blob_name} during cleanup: {e}"
|
||||
f"[CloudStorage] Failed to process file {blob_name} during cleanup: {e}"
|
||||
)
|
||||
# Skip files with invalid metadata or delete errors
|
||||
pass
|
||||
@@ -447,7 +548,7 @@ class CloudStorageHandler:
|
||||
|
||||
except Exception as e:
|
||||
# Log the error for debugging but continue operation
|
||||
logger.error(f"Cleanup operation failed: {e}")
|
||||
logger.error(f"[CloudStorage] Cleanup operation failed: {e}")
|
||||
# Return 0 - we'll try again next cleanup cycle
|
||||
return 0
|
||||
|
||||
@@ -476,7 +577,7 @@ class CloudStorageHandler:
|
||||
|
||||
bucket_name, blob_name = parts
|
||||
|
||||
async_client = self._get_async_gcs_client()
|
||||
async_client = await self._get_async_gcs_client()
|
||||
|
||||
try:
|
||||
# Get object metadata using pure async client
|
||||
@@ -490,11 +591,15 @@ class CloudStorageHandler:
|
||||
except Exception as e:
|
||||
# If file doesn't exist or we can't read metadata
|
||||
if "404" in str(e) or "Not Found" in str(e):
|
||||
logger.debug(f"File not found during expiration check: {blob_name}")
|
||||
logger.warning(
|
||||
f"[CloudStorage] File not found during expiration check: {blob_name}"
|
||||
)
|
||||
return True # File doesn't exist, consider it expired
|
||||
|
||||
# Log other types of errors for debugging
|
||||
logger.warning(f"Failed to check expiration for {blob_name}: {e}")
|
||||
logger.warning(
|
||||
f"[CloudStorage] Failed to check expiration for {blob_name}: {e}"
|
||||
)
|
||||
# If we can't read metadata for other reasons, assume not expired
|
||||
return False
|
||||
|
||||
@@ -544,11 +649,15 @@ async def cleanup_expired_files_async() -> int:
|
||||
# Use cleanup lock to prevent concurrent cleanup operations
|
||||
async with _cleanup_lock:
|
||||
try:
|
||||
logger.info("Starting cleanup of expired cloud storage files")
|
||||
logger.info(
|
||||
"[CloudStorage] Starting cleanup of expired cloud storage files"
|
||||
)
|
||||
handler = await get_cloud_storage_handler()
|
||||
deleted_count = await handler.delete_expired_files()
|
||||
logger.info(f"Cleaned up {deleted_count} expired files from cloud storage")
|
||||
logger.info(
|
||||
f"[CloudStorage] Cleaned up {deleted_count} expired files from cloud storage"
|
||||
)
|
||||
return deleted_count
|
||||
except Exception as e:
|
||||
logger.error(f"Error during cloud storage cleanup: {e}")
|
||||
logger.error(f"[CloudStorage] Error during cloud storage cleanup: {e}")
|
||||
return 0
|
||||
|
||||
@@ -72,16 +72,17 @@ class TestCloudStorageHandler:
|
||||
assert call_args[0][2] == content # file content
|
||||
assert "metadata" in call_args[1] # metadata argument
|
||||
|
||||
@patch.object(CloudStorageHandler, "_get_async_gcs_client")
|
||||
@patch("backend.util.cloud_storage.async_gcs_storage.Storage")
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_file_gcs(self, mock_get_async_client, handler):
|
||||
async def test_retrieve_file_gcs(self, mock_storage_class, handler):
|
||||
"""Test retrieving file from GCS."""
|
||||
# Mock async GCS client
|
||||
# Mock async GCS client instance
|
||||
mock_async_client = AsyncMock()
|
||||
mock_get_async_client.return_value = mock_async_client
|
||||
mock_storage_class.return_value = mock_async_client
|
||||
|
||||
# Mock the download method
|
||||
# Mock the download and close methods
|
||||
mock_async_client.download = AsyncMock(return_value=b"test content")
|
||||
mock_async_client.close = AsyncMock()
|
||||
|
||||
result = await handler.retrieve_file(
|
||||
"gcs://test-bucket/uploads/system/uuid123/file.txt"
|
||||
@@ -92,16 +93,17 @@ class TestCloudStorageHandler:
|
||||
"test-bucket", "uploads/system/uuid123/file.txt"
|
||||
)
|
||||
|
||||
@patch.object(CloudStorageHandler, "_get_async_gcs_client")
|
||||
@patch("backend.util.cloud_storage.async_gcs_storage.Storage")
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_file_not_found(self, mock_get_async_client, handler):
|
||||
async def test_retrieve_file_not_found(self, mock_storage_class, handler):
|
||||
"""Test retrieving non-existent file from GCS."""
|
||||
# Mock async GCS client
|
||||
# Mock async GCS client instance
|
||||
mock_async_client = AsyncMock()
|
||||
mock_get_async_client.return_value = mock_async_client
|
||||
mock_storage_class.return_value = mock_async_client
|
||||
|
||||
# Mock the download method to raise a 404 exception
|
||||
mock_async_client.download = AsyncMock(side_effect=Exception("404 Not Found"))
|
||||
mock_async_client.close = AsyncMock()
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await handler.retrieve_file(
|
||||
@@ -287,14 +289,15 @@ class TestCloudStorageHandler:
|
||||
):
|
||||
handler._validate_file_access("invalid/path/file.txt", "user123")
|
||||
|
||||
@patch.object(CloudStorageHandler, "_get_async_gcs_client")
|
||||
@patch("backend.util.cloud_storage.async_gcs_storage.Storage")
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_file_with_authorization(self, mock_get_client, handler):
|
||||
async def test_retrieve_file_with_authorization(self, mock_storage_class, handler):
|
||||
"""Test file retrieval with authorization."""
|
||||
# Mock async GCS client
|
||||
# Mock async GCS client instance
|
||||
mock_client = AsyncMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_storage_class.return_value = mock_client
|
||||
mock_client.download = AsyncMock(return_value=b"test content")
|
||||
mock_client.close = AsyncMock()
|
||||
|
||||
# Test successful retrieval of user's own file
|
||||
result = await handler.retrieve_file(
|
||||
@@ -412,18 +415,19 @@ class TestCloudStorageHandler:
|
||||
"uploads/executions/exec123/uuid456/file.txt", graph_exec_id="exec456"
|
||||
)
|
||||
|
||||
@patch.object(CloudStorageHandler, "_get_async_gcs_client")
|
||||
@patch("backend.util.cloud_storage.async_gcs_storage.Storage")
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_file_with_exec_authorization(
|
||||
self, mock_get_async_client, handler
|
||||
self, mock_storage_class, handler
|
||||
):
|
||||
"""Test file retrieval with execution authorization."""
|
||||
# Mock async GCS client
|
||||
# Mock async GCS client instance
|
||||
mock_async_client = AsyncMock()
|
||||
mock_get_async_client.return_value = mock_async_client
|
||||
mock_storage_class.return_value = mock_async_client
|
||||
|
||||
# Mock the download method
|
||||
# Mock the download and close methods
|
||||
mock_async_client.download = AsyncMock(return_value=b"test content")
|
||||
mock_async_client.close = AsyncMock()
|
||||
|
||||
# Test successful retrieval of execution's own file
|
||||
result = await handler.retrieve_file(
|
||||
|
||||
@@ -5,7 +5,7 @@ from functools import wraps
|
||||
from typing import Any, Awaitable, Callable, TypeVar
|
||||
|
||||
import ldclient
|
||||
from autogpt_libs.utils.cache import async_ttl_cache
|
||||
from autogpt_libs.utils.cache import cached
|
||||
from fastapi import HTTPException
|
||||
from ldclient import Context, LDClient
|
||||
from ldclient.config import Config
|
||||
@@ -72,7 +72,7 @@ def shutdown_launchdarkly() -> None:
|
||||
logger.info("LaunchDarkly client closed successfully")
|
||||
|
||||
|
||||
@async_ttl_cache(maxsize=1000, ttl_seconds=86400) # 1000 entries, 24 hours TTL
|
||||
@cached(maxsize=1000, ttl_seconds=86400) # 1000 entries, 24 hours TTL
|
||||
async def _fetch_user_context_data(user_id: str) -> Context:
|
||||
"""
|
||||
Fetch user context for LaunchDarkly from Supabase.
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Type, TypeGuard, TypeVar, overload
|
||||
|
||||
import jsonschema
|
||||
import orjson
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from prisma import Json
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .type import type_match
|
||||
|
||||
# Precompiled regex to remove PostgreSQL-incompatible control characters
|
||||
# Removes \u0000-\u0008, \u000B-\u000C, \u000E-\u001F, \u007F (keeps tab \u0009, newline \u000A, carriage return \u000D)
|
||||
POSTGRES_CONTROL_CHARS = re.compile(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]")
|
||||
|
||||
|
||||
def to_dict(data) -> dict:
|
||||
if isinstance(data, BaseModel):
|
||||
@@ -15,7 +21,9 @@ def to_dict(data) -> dict:
|
||||
return jsonable_encoder(data)
|
||||
|
||||
|
||||
def dumps(data: Any, *args: Any, **kwargs: Any) -> str:
|
||||
def dumps(
|
||||
data: Any, *args: Any, indent: int | None = None, option: int = 0, **kwargs: Any
|
||||
) -> str:
|
||||
"""
|
||||
Serialize data to JSON string with automatic conversion of Pydantic models and complex types.
|
||||
|
||||
@@ -28,9 +36,13 @@ def dumps(data: Any, *args: Any, **kwargs: Any) -> str:
|
||||
data : Any
|
||||
The data to serialize. Can be any type including Pydantic models, dicts, lists, etc.
|
||||
*args : Any
|
||||
Additional positional arguments passed to json.dumps()
|
||||
Additional positional arguments
|
||||
indent : int | None
|
||||
If not None, pretty-print with indentation
|
||||
option : int
|
||||
orjson option flags (default: 0)
|
||||
**kwargs : Any
|
||||
Additional keyword arguments passed to json.dumps() (e.g., indent, separators)
|
||||
Additional keyword arguments. Supported: default, ensure_ascii, separators, indent
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -45,7 +57,21 @@ def dumps(data: Any, *args: Any, **kwargs: Any) -> str:
|
||||
>>> dumps(pydantic_model_instance, indent=2)
|
||||
'{\n "field1": "value1",\n "field2": "value2"\n}'
|
||||
"""
|
||||
return json.dumps(to_dict(data), *args, **kwargs)
|
||||
serializable_data = to_dict(data)
|
||||
|
||||
# Handle indent parameter
|
||||
if indent is not None or kwargs.get("indent") is not None:
|
||||
option |= orjson.OPT_INDENT_2
|
||||
|
||||
# orjson only accepts specific parameters, filter out stdlib json params
|
||||
# ensure_ascii: orjson always produces UTF-8 (better than ASCII)
|
||||
# separators: orjson uses compact separators by default
|
||||
supported_orjson_params = {"default"}
|
||||
orjson_kwargs = {k: v for k, v in kwargs.items() if k in supported_orjson_params}
|
||||
|
||||
return orjson.dumps(serializable_data, option=option, **orjson_kwargs).decode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -62,9 +88,8 @@ def loads(data: str | bytes, *args, **kwargs) -> Any: ...
|
||||
def loads(
|
||||
data: str | bytes, *args, target_type: Type[T] | None = None, **kwargs
|
||||
) -> Any:
|
||||
if isinstance(data, bytes):
|
||||
data = data.decode("utf-8")
|
||||
parsed = json.loads(data, *args, **kwargs)
|
||||
parsed = orjson.loads(data)
|
||||
|
||||
if target_type:
|
||||
return type_match(parsed, target_type)
|
||||
return parsed
|
||||
@@ -99,16 +124,19 @@ def convert_pydantic_to_json(output_data: Any) -> Any:
|
||||
|
||||
|
||||
def SafeJson(data: Any) -> Json:
|
||||
"""Safely serialize data and return Prisma's Json type."""
|
||||
"""
|
||||
Safely serialize data and return Prisma's Json type.
|
||||
Sanitizes null bytes to prevent PostgreSQL 22P05 errors.
|
||||
"""
|
||||
if isinstance(data, BaseModel):
|
||||
return Json(
|
||||
data.model_dump(
|
||||
mode="json",
|
||||
warnings="error",
|
||||
exclude_none=True,
|
||||
fallback=lambda v: None,
|
||||
)
|
||||
json_string = data.model_dump_json(
|
||||
warnings="error",
|
||||
exclude_none=True,
|
||||
fallback=lambda v: None,
|
||||
)
|
||||
# Round-trip through JSON to ensure proper serialization with fallback for non-serializable values
|
||||
json_string = dumps(data, default=lambda v: None)
|
||||
return Json(json.loads(json_string))
|
||||
else:
|
||||
json_string = dumps(data, default=lambda v: None)
|
||||
|
||||
# Remove PostgreSQL-incompatible control characters in single regex operation
|
||||
sanitized_json = POSTGRES_CONTROL_CHARS.sub("", json_string)
|
||||
return Json(json.loads(sanitized_json))
|
||||
|
||||
@@ -4,6 +4,7 @@ from enum import Enum
|
||||
import sentry_sdk
|
||||
from pydantic import SecretStr
|
||||
from sentry_sdk.integrations.anthropic import AnthropicIntegration
|
||||
from sentry_sdk.integrations.asyncio import AsyncioIntegration
|
||||
from sentry_sdk.integrations.logging import LoggingIntegration
|
||||
|
||||
from backend.util.settings import Settings
|
||||
@@ -25,6 +26,7 @@ def sentry_init():
|
||||
environment=f"app:{settings.config.app_env.value}-behave:{settings.config.behave_as.value}",
|
||||
_experiments={"enable_logs": True},
|
||||
integrations=[
|
||||
AsyncioIntegration(),
|
||||
LoggingIntegration(sentry_logs_level=logging.INFO),
|
||||
AnthropicIntegration(
|
||||
include_prompts=False,
|
||||
|
||||
@@ -17,6 +17,37 @@ from backend.util.process import get_service_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Alert threshold for excessive retries
|
||||
EXCESSIVE_RETRY_THRESHOLD = 50
|
||||
|
||||
|
||||
def _send_critical_retry_alert(
|
||||
func_name: str, attempt_number: int, exception: Exception, context: str = ""
|
||||
):
|
||||
"""Send alert when a function is approaching the retry failure threshold."""
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from backend.util.clients import get_notification_manager_client
|
||||
|
||||
notification_client = get_notification_manager_client()
|
||||
|
||||
prefix = f"{context}: " if context else ""
|
||||
alert_msg = (
|
||||
f"🚨 CRITICAL: Operation Approaching Failure Threshold: {prefix}'{func_name}'\n\n"
|
||||
f"Current attempt: {attempt_number}/{EXCESSIVE_RETRY_THRESHOLD}\n"
|
||||
f"Error: {type(exception).__name__}: {exception}\n\n"
|
||||
f"This operation is about to fail permanently. Investigate immediately."
|
||||
)
|
||||
|
||||
notification_client.discord_system_alert(alert_msg)
|
||||
logger.critical(
|
||||
f"CRITICAL ALERT SENT: Operation {func_name} at attempt {attempt_number}"
|
||||
)
|
||||
|
||||
except Exception as alert_error:
|
||||
logger.error(f"Failed to send critical retry alert: {alert_error}")
|
||||
# Don't let alerting failures break the main flow
|
||||
|
||||
|
||||
def _create_retry_callback(context: str = ""):
|
||||
"""Create a retry callback with optional context."""
|
||||
@@ -29,17 +60,22 @@ def _create_retry_callback(context: str = ""):
|
||||
prefix = f"{context}: " if context else ""
|
||||
|
||||
if retry_state.outcome.failed and retry_state.next_action is None:
|
||||
# Final failure
|
||||
# Final failure - just log the error (alert was already sent at excessive threshold)
|
||||
logger.error(
|
||||
f"{prefix}Giving up after {attempt_number} attempts for '{func_name}': "
|
||||
f"{type(exception).__name__}: {exception}"
|
||||
)
|
||||
else:
|
||||
# Retry attempt
|
||||
logger.warning(
|
||||
f"{prefix}Retry attempt {attempt_number} for '{func_name}': "
|
||||
f"{type(exception).__name__}: {exception}"
|
||||
)
|
||||
# Retry attempt - send critical alert only once at threshold
|
||||
if attempt_number == EXCESSIVE_RETRY_THRESHOLD:
|
||||
_send_critical_retry_alert(
|
||||
func_name, attempt_number, exception, context
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"{prefix}Retry attempt {attempt_number} for '{func_name}': "
|
||||
f"{type(exception).__name__}: {exception}"
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import os
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property, update_wrapper
|
||||
from functools import update_wrapper
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
@@ -43,6 +43,7 @@ api_host = config.pyro_host
|
||||
api_comm_retry = config.pyro_client_comm_retry
|
||||
api_comm_timeout = config.pyro_client_comm_timeout
|
||||
api_call_timeout = config.rpc_client_call_timeout
|
||||
api_comm_max_wait = config.pyro_client_max_wait
|
||||
|
||||
|
||||
def _validate_no_prisma_objects(obj: Any, path: str = "result") -> None:
|
||||
@@ -352,7 +353,7 @@ def get_service_client(
|
||||
# Use preconfigured retry decorator for service communication
|
||||
return create_retry_decorator(
|
||||
max_attempts=api_comm_retry,
|
||||
max_wait=5.0,
|
||||
max_wait=api_comm_max_wait,
|
||||
context="Service communication",
|
||||
exclude_exceptions=(
|
||||
# Don't retry these specific exceptions that won't be fixed by retrying
|
||||
@@ -374,6 +375,8 @@ def get_service_client(
|
||||
self.base_url = f"http://{host}:{port}".rstrip("/")
|
||||
self._connection_failure_count = 0
|
||||
self._last_client_reset = 0
|
||||
self._async_clients = {} # None key for default async client
|
||||
self._sync_clients = {} # For sync clients (no event loop concept)
|
||||
|
||||
def _create_sync_client(self) -> httpx.Client:
|
||||
return httpx.Client(
|
||||
@@ -397,13 +400,33 @@ def get_service_client(
|
||||
),
|
||||
)
|
||||
|
||||
@cached_property
|
||||
@property
|
||||
def sync_client(self) -> httpx.Client:
|
||||
return self._create_sync_client()
|
||||
"""Get the sync client (thread-safe singleton)."""
|
||||
# Use service name as key for better identification
|
||||
service_name = service_client_type.get_service_type().__name__
|
||||
if client := self._sync_clients.get(service_name):
|
||||
return client
|
||||
return self._sync_clients.setdefault(
|
||||
service_name, self._create_sync_client()
|
||||
)
|
||||
|
||||
@cached_property
|
||||
@property
|
||||
def async_client(self) -> httpx.AsyncClient:
|
||||
return self._create_async_client()
|
||||
"""Get the appropriate async client for the current context.
|
||||
|
||||
Returns per-event-loop client when in async context,
|
||||
falls back to default client otherwise.
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# No event loop, use None as default key
|
||||
loop = None
|
||||
|
||||
if client := self._async_clients.get(loop):
|
||||
return client
|
||||
return self._async_clients.setdefault(loop, self._create_async_client())
|
||||
|
||||
def _handle_connection_error(self, error: Exception) -> None:
|
||||
"""Handle connection errors and implement self-healing"""
|
||||
@@ -422,10 +445,8 @@ def get_service_client(
|
||||
|
||||
# Clear cached clients to force recreation on next access
|
||||
# Only recreate when there's actually a problem
|
||||
if hasattr(self, "sync_client"):
|
||||
delattr(self, "sync_client")
|
||||
if hasattr(self, "async_client"):
|
||||
delattr(self, "async_client")
|
||||
self._sync_clients.clear()
|
||||
self._async_clients.clear()
|
||||
|
||||
# Reset counters
|
||||
self._connection_failure_count = 0
|
||||
@@ -491,28 +512,37 @@ def get_service_client(
|
||||
raise
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if hasattr(self, "sync_client"):
|
||||
self.sync_client.close()
|
||||
if hasattr(self, "async_client"):
|
||||
await self.async_client.aclose()
|
||||
# Close all sync clients
|
||||
for client in self._sync_clients.values():
|
||||
client.close()
|
||||
self._sync_clients.clear()
|
||||
|
||||
# Close all async clients (including default with None key)
|
||||
for client in self._async_clients.values():
|
||||
await client.aclose()
|
||||
self._async_clients.clear()
|
||||
|
||||
def close(self) -> None:
|
||||
if hasattr(self, "sync_client"):
|
||||
self.sync_client.close()
|
||||
# Note: Cannot close async client synchronously
|
||||
# Close all sync clients
|
||||
for client in self._sync_clients.values():
|
||||
client.close()
|
||||
self._sync_clients.clear()
|
||||
# Note: Cannot close async clients synchronously
|
||||
# They will be cleaned up by garbage collection
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup HTTP clients on garbage collection to prevent resource leaks."""
|
||||
try:
|
||||
if hasattr(self, "sync_client"):
|
||||
self.sync_client.close()
|
||||
if hasattr(self, "async_client"):
|
||||
# Note: Can't await in __del__, so we just close sync
|
||||
# The async client will be cleaned up by garbage collection
|
||||
# Close any remaining sync clients
|
||||
for client in self._sync_clients.values():
|
||||
client.close()
|
||||
|
||||
# Warn if async clients weren't properly closed
|
||||
if self._async_clients:
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"DynamicClient async client not explicitly closed. "
|
||||
"DynamicClient async clients not explicitly closed. "
|
||||
"Call aclose() before destroying the client.",
|
||||
ResourceWarning,
|
||||
stacklevel=2,
|
||||
|
||||
@@ -59,6 +59,19 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
le=1000,
|
||||
description="Maximum number of workers to use for graph execution.",
|
||||
)
|
||||
|
||||
# FastAPI Thread Pool Configuration
|
||||
# IMPORTANT: FastAPI automatically offloads ALL sync functions to a thread pool:
|
||||
# - Sync endpoint functions (def instead of async def)
|
||||
# - Sync dependency functions (def instead of async def)
|
||||
# - Manually called run_in_threadpool() operations
|
||||
# Default thread pool size is only 40, which becomes a bottleneck under high concurrency
|
||||
fastapi_thread_pool_size: int = Field(
|
||||
default=60,
|
||||
ge=40,
|
||||
le=500,
|
||||
description="Thread pool size for FastAPI sync operations. All sync endpoints and dependencies automatically use this pool. Higher values support more concurrent sync operations but use more memory.",
|
||||
)
|
||||
pyro_host: str = Field(
|
||||
default="localhost",
|
||||
description="The default hostname of the Pyro server.",
|
||||
@@ -68,9 +81,13 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="The default timeout in seconds, for Pyro client connections.",
|
||||
)
|
||||
pyro_client_comm_retry: int = Field(
|
||||
default=5,
|
||||
default=100,
|
||||
description="The default number of retries for Pyro client connections.",
|
||||
)
|
||||
pyro_client_max_wait: float = Field(
|
||||
default=30.0,
|
||||
description="The maximum wait time in seconds for Pyro client retries.",
|
||||
)
|
||||
rpc_client_call_timeout: int = Field(
|
||||
default=300,
|
||||
description="The default timeout in seconds, for RPC client calls.",
|
||||
@@ -123,6 +140,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default=5 * 60,
|
||||
description="Time in seconds after which the execution stuck on QUEUED status is considered late.",
|
||||
)
|
||||
cluster_lock_timeout: int = Field(
|
||||
default=300,
|
||||
description="Cluster lock timeout in seconds for graph execution coordination.",
|
||||
)
|
||||
execution_late_notification_checkrange_secs: int = Field(
|
||||
default=60 * 60,
|
||||
description="Time in seconds for how far back to check for the late executions.",
|
||||
|
||||
@@ -215,3 +215,29 @@ class TestSafeJson:
|
||||
}
|
||||
result = SafeJson(data)
|
||||
assert isinstance(result, Json)
|
||||
|
||||
def test_control_character_sanitization(self):
|
||||
"""Test that PostgreSQL-incompatible control characters are sanitized by SafeJson."""
|
||||
# Test data with problematic control characters that would cause PostgreSQL errors
|
||||
problematic_data = {
|
||||
"null_byte": "data with \x00 null",
|
||||
"bell_char": "data with \x07 bell",
|
||||
"form_feed": "data with \x0C feed",
|
||||
"escape_char": "data with \x1B escape",
|
||||
"delete_char": "data with \x7F delete",
|
||||
}
|
||||
|
||||
# SafeJson should successfully process data with control characters
|
||||
result = SafeJson(problematic_data)
|
||||
assert isinstance(result, Json)
|
||||
|
||||
# Test that safe whitespace characters are preserved
|
||||
safe_data = {
|
||||
"with_tab": "text with \t tab",
|
||||
"with_newline": "text with \n newline",
|
||||
"with_carriage_return": "text with \r carriage return",
|
||||
"normal_text": "completely normal text",
|
||||
}
|
||||
|
||||
safe_result = SafeJson(safe_data)
|
||||
assert isinstance(safe_result, Json)
|
||||
|
||||
@@ -16,8 +16,8 @@ def format_filter_for_jinja2(value, format_string=None):
|
||||
|
||||
|
||||
class TextFormatter:
|
||||
def __init__(self):
|
||||
self.env = SandboxedEnvironment(loader=BaseLoader(), autoescape=True)
|
||||
def __init__(self, autoescape: bool = True):
|
||||
self.env = SandboxedEnvironment(loader=BaseLoader(), autoescape=autoescape)
|
||||
self.env.globals.clear()
|
||||
|
||||
# Instead of clearing all filters, just remove potentially unsafe ones
|
||||
|
||||
@@ -2,9 +2,16 @@ import asyncio
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import aioclamd
|
||||
# Suppress the specific pkg_resources deprecation warning from aioclamd
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore", message="pkg_resources is deprecated", category=UserWarning
|
||||
)
|
||||
import aioclamd
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Clean the test database by removing all data while preserving the schema.
|
||||
|
||||
Usage:
|
||||
poetry run python clean_test_db.py [--yes]
|
||||
|
||||
Options:
|
||||
--yes Skip confirmation prompt
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
from prisma import Prisma
|
||||
|
||||
|
||||
async def main():
|
||||
db = Prisma()
|
||||
await db.connect()
|
||||
|
||||
print("=" * 60)
|
||||
print("Cleaning Test Database")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
# Get initial counts
|
||||
user_count = await db.user.count()
|
||||
agent_count = await db.agentgraph.count()
|
||||
|
||||
print(f"Current data: {user_count} users, {agent_count} agent graphs")
|
||||
|
||||
if user_count == 0 and agent_count == 0:
|
||||
print("Database is already clean!")
|
||||
await db.disconnect()
|
||||
return
|
||||
|
||||
# Check for --yes flag
|
||||
skip_confirm = "--yes" in sys.argv
|
||||
|
||||
if not skip_confirm:
|
||||
response = input("\nDo you want to clean all data? (yes/no): ")
|
||||
if response.lower() != "yes":
|
||||
print("Aborted.")
|
||||
await db.disconnect()
|
||||
return
|
||||
|
||||
print("\nCleaning database...")
|
||||
|
||||
# Delete in reverse order of dependencies
|
||||
tables = [
|
||||
("UserNotificationBatch", db.usernotificationbatch),
|
||||
("NotificationEvent", db.notificationevent),
|
||||
("CreditRefundRequest", db.creditrefundrequest),
|
||||
("StoreListingReview", db.storelistingreview),
|
||||
("StoreListingVersion", db.storelistingversion),
|
||||
("StoreListing", db.storelisting),
|
||||
("AgentNodeExecutionInputOutput", db.agentnodeexecutioninputoutput),
|
||||
("AgentNodeExecution", db.agentnodeexecution),
|
||||
("AgentGraphExecution", db.agentgraphexecution),
|
||||
("AgentNodeLink", db.agentnodelink),
|
||||
("LibraryAgent", db.libraryagent),
|
||||
("AgentPreset", db.agentpreset),
|
||||
("IntegrationWebhook", db.integrationwebhook),
|
||||
("AgentNode", db.agentnode),
|
||||
("AgentGraph", db.agentgraph),
|
||||
("AgentBlock", db.agentblock),
|
||||
("APIKey", db.apikey),
|
||||
("CreditTransaction", db.credittransaction),
|
||||
("AnalyticsMetrics", db.analyticsmetrics),
|
||||
("AnalyticsDetails", db.analyticsdetails),
|
||||
("Profile", db.profile),
|
||||
("UserOnboarding", db.useronboarding),
|
||||
("User", db.user),
|
||||
]
|
||||
|
||||
for table_name, table in tables:
|
||||
try:
|
||||
count = await table.count()
|
||||
if count > 0:
|
||||
await table.delete_many()
|
||||
print(f"✓ Deleted {count} records from {table_name}")
|
||||
except Exception as e:
|
||||
print(f"⚠ Error cleaning {table_name}: {e}")
|
||||
|
||||
# Refresh materialized views (they should be empty now)
|
||||
try:
|
||||
await db.execute_raw("SELECT refresh_store_materialized_views();")
|
||||
print("\n✓ Refreshed materialized views")
|
||||
except Exception as e:
|
||||
print(f"\n⚠ Could not refresh materialized views: {e}")
|
||||
|
||||
await db.disconnect()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Database cleaned successfully!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
18
autogpt_platform/backend/load-tests/.gitignore
vendored
Normal file
18
autogpt_platform/backend/load-tests/.gitignore
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
# Load testing credentials and sensitive data
|
||||
configs/pre-authenticated-tokens.js
|
||||
configs/k6-credentials.env
|
||||
results/
|
||||
k6-cloud-results.txt
|
||||
|
||||
# Node.js
|
||||
node_modules/
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
|
||||
# Environment files
|
||||
.env
|
||||
.env.local
|
||||
.env.development.local
|
||||
.env.test.local
|
||||
.env.production.local
|
||||
296
autogpt_platform/backend/load-tests/README.md
Normal file
296
autogpt_platform/backend/load-tests/README.md
Normal file
@@ -0,0 +1,296 @@
|
||||
# AutoGPT Platform Load Tests
|
||||
|
||||
Clean, streamlined load testing infrastructure for the AutoGPT Platform using k6.
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
```bash
|
||||
# 1. Set up Supabase service key (required for token generation)
|
||||
export SUPABASE_SERVICE_KEY="your-supabase-service-key"
|
||||
|
||||
# 2. Generate pre-authenticated tokens (first time setup - creates 160+ tokens with 24-hour expiry)
|
||||
node generate-tokens.js --count=160
|
||||
|
||||
# 3. Set up k6 cloud credentials (for cloud testing - see Credential Setup section below)
|
||||
export K6_CLOUD_TOKEN="your-k6-cloud-token"
|
||||
export K6_CLOUD_PROJECT_ID="4254406"
|
||||
|
||||
# 4. Run orchestrated load tests locally
|
||||
node orchestrator/orchestrator.js DEV local
|
||||
|
||||
# 5. Run orchestrated load tests in k6 cloud (recommended)
|
||||
node orchestrator/orchestrator.js DEV cloud
|
||||
```
|
||||
|
||||
## 📋 Load Test Orchestrator
|
||||
|
||||
The AutoGPT Platform uses a comprehensive load test orchestrator (`orchestrator/orchestrator.js`) that runs 12 optimized tests with maximum VU counts:
|
||||
|
||||
### Available Tests
|
||||
|
||||
#### Basic Tests (Simple validation)
|
||||
|
||||
- **connectivity-test**: Basic connectivity and authentication validation
|
||||
- **single-endpoint-test**: Individual API endpoint testing with high concurrency
|
||||
|
||||
#### API Tests (Core functionality)
|
||||
|
||||
- **core-api-test**: Core API endpoints (`/api/credits`, `/api/graphs`, `/api/blocks`, `/api/executions`)
|
||||
- **graph-execution-test**: Complete graph creation and execution pipeline
|
||||
|
||||
#### Marketplace Tests (User-facing features)
|
||||
|
||||
- **marketplace-public-test**: Public marketplace browsing and search
|
||||
- **marketplace-library-test**: Authenticated marketplace and user library operations
|
||||
|
||||
#### Comprehensive Tests (End-to-end scenarios)
|
||||
|
||||
- **comprehensive-test**: Complete user journey simulation with multiple operations
|
||||
|
||||
### Test Modes
|
||||
|
||||
- **Local Mode**: 5 VUs × 30s - Quick validation and debugging
|
||||
- **Cloud Mode**: 80-160 VUs × 3-6m - Real performance testing
|
||||
|
||||
## 🛠️ Usage
|
||||
|
||||
### Basic Commands
|
||||
|
||||
```bash
|
||||
# Run 12 optimized tests locally (for debugging)
|
||||
node orchestrator/orchestrator.js DEV local
|
||||
|
||||
# Run 12 optimized tests in k6 cloud (recommended for performance testing)
|
||||
node orchestrator/orchestrator.js DEV cloud
|
||||
|
||||
# Run against production (coordinate with team!)
|
||||
node orchestrator/orchestrator.js PROD cloud
|
||||
|
||||
# Run individual test directly with k6
|
||||
K6_ENVIRONMENT=DEV VUS=100 DURATION=3m k6 run tests/api/core-api-test.js
|
||||
```
|
||||
|
||||
### NPM Scripts
|
||||
|
||||
```bash
|
||||
# Run orchestrator locally
|
||||
npm run local
|
||||
|
||||
# Run orchestrator in k6 cloud
|
||||
npm run cloud
|
||||
```
|
||||
|
||||
## 🔧 Test Configuration
|
||||
|
||||
### Pre-Authenticated Tokens
|
||||
|
||||
- **Generation**: Run `node generate-tokens.js --count=160` to create tokens
|
||||
- **File**: `configs/pre-authenticated-tokens.js` (gitignored for security)
|
||||
- **Capacity**: 160+ tokens supporting high-concurrency testing
|
||||
- **Expiry**: 24 hours (86400 seconds) - extended for long-duration testing
|
||||
- **Benefit**: Eliminates Supabase auth rate limiting at scale
|
||||
- **Regeneration**: Run `node generate-tokens.js --count=160` when tokens expire after 24 hours
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
- **LOCAL**: `http://localhost:8006` (local development)
|
||||
- **DEV**: `https://dev-api.agpt.co` (development environment)
|
||||
- **PROD**: `https://api.agpt.co` (production environment - coordinate with team!)
|
||||
|
||||
## 📊 Performance Testing Features
|
||||
|
||||
### Real-Time Monitoring
|
||||
|
||||
- **k6 Cloud Dashboard**: Live performance metrics during cloud test execution
|
||||
- **URL Tracking**: Test URLs automatically saved to `k6-cloud-results.txt`
|
||||
- **Error Tracking**: Detailed failure analysis and HTTP status monitoring
|
||||
- **Custom Metrics**: Request success/failure rates, response times, user journey tracking
|
||||
- **Authentication Monitoring**: Tracks auth success/failure rates separately from HTTP errors
|
||||
|
||||
### Load Testing Capabilities
|
||||
|
||||
- **High Concurrency**: Up to 160+ virtual users per test
|
||||
- **Authentication Scaling**: Pre-auth tokens support 160+ concurrent users
|
||||
- **Sequential Execution**: Multiple tests run one after another with proper delays
|
||||
- **Cloud Infrastructure**: Tests run on k6 cloud servers for consistent results
|
||||
- **ES Module Support**: Full ES module compatibility with modern JavaScript features
|
||||
|
||||
## 📈 Performance Expectations
|
||||
|
||||
### Validated Performance Limits
|
||||
|
||||
- **Core API**: 100+ VUs successfully handling `/api/credits`, `/api/graphs`, `/api/blocks`, `/api/executions`
|
||||
- **Graph Execution**: 80+ VUs for complete workflow pipeline
|
||||
- **Marketplace Browsing**: 160 VUs for public marketplace access (verified)
|
||||
- **Marketplace Library**: 160 VUs for authenticated library operations (verified)
|
||||
- **Authentication**: 160+ concurrent users with pre-authenticated tokens
|
||||
|
||||
### Target Metrics
|
||||
|
||||
- **P95 Latency**: Target < 5 seconds (marketplace), < 2 seconds (core API)
|
||||
- **P99 Latency**: Target < 10 seconds (marketplace), < 5 seconds (core API)
|
||||
- **Success Rate**: Target > 95% under normal load
|
||||
- **Error Rate**: Target < 5% for all endpoints
|
||||
|
||||
### Recent Performance Results (160 VU Test - Verified)
|
||||
|
||||
- **Marketplace Library Operations**: 500-1000ms response times at 160 VUs
|
||||
- **Authentication**: 100% success rate with pre-authenticated tokens
|
||||
- **Library Journeys**: 5 operations per journey completing successfully
|
||||
- **Test Duration**: 6+ minutes sustained load without degradation
|
||||
- **k6 Cloud Execution**: Stable performance on Amazon US Columbus infrastructure
|
||||
|
||||
## 🔍 Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
**1. Authentication Failures**
|
||||
|
||||
```
|
||||
❌ No valid authentication token available
|
||||
❌ Token has expired
|
||||
```
|
||||
|
||||
- **Solution**: Run `node generate-tokens.js --count=160` to create fresh 24-hour tokens
|
||||
- **Note**: Use `--count` parameter to generate appropriate number of tokens for your test scale
|
||||
|
||||
**2. Cloud Credentials Missing**
|
||||
|
||||
```
|
||||
❌ Missing k6 cloud credentials
|
||||
```
|
||||
|
||||
- **Solution**: Set `K6_CLOUD_TOKEN` and `K6_CLOUD_PROJECT_ID=4254406`
|
||||
|
||||
**3. k6 Cloud VU Scaling Issue**
|
||||
|
||||
```
|
||||
❌ Test shows only 5 VUs instead of requested 100+ VUs
|
||||
```
|
||||
|
||||
- **Problem**: Using `K6_ENVIRONMENT=DEV VUS=160 k6 cloud run test.js` (incorrect)
|
||||
- **Solution**: Use `k6 cloud run --env K6_ENVIRONMENT=DEV --env VUS=160 test.js` (correct)
|
||||
- **Note**: The unified test runner (`run-tests.js`) already uses the correct syntax
|
||||
|
||||
**4. Setup Verification Failed**
|
||||
|
||||
```
|
||||
❌ Verification failed
|
||||
```
|
||||
|
||||
- **Solution**: Check tokens exist and local API is accessible
|
||||
|
||||
### Required Setup
|
||||
|
||||
**1. Supabase Service Key (Required for all testing):**
|
||||
|
||||
```bash
|
||||
# Option 1: From your local environment (if available)
|
||||
export SUPABASE_SERVICE_KEY="your-supabase-service-key"
|
||||
|
||||
# Option 2: From Kubernetes secret (for platform developers)
|
||||
kubectl get secret supabase-service-key -o jsonpath='{.data.service-key}' | base64 -d
|
||||
|
||||
# Option 3: From Supabase dashboard
|
||||
# Go to Project Settings > API > service_role key (never commit this!)
|
||||
```
|
||||
|
||||
**2. Generate Pre-Authenticated Tokens (Required):**
|
||||
|
||||
```bash
|
||||
# Creates 160 tokens with 24-hour expiry - prevents auth rate limiting
|
||||
node generate-tokens.js --count=160
|
||||
|
||||
# Generate fewer tokens for smaller tests (minimum 10)
|
||||
node generate-tokens.js --count=50
|
||||
|
||||
# Regenerate when tokens expire (every 24 hours)
|
||||
node generate-tokens.js --count=160
|
||||
```
|
||||
|
||||
**3. k6 Cloud Credentials (Required for cloud testing):**
|
||||
|
||||
```bash
|
||||
# Get from k6 cloud dashboard: https://app.k6.io/account/api-token
|
||||
export K6_CLOUD_TOKEN="your-k6-cloud-token"
|
||||
export K6_CLOUD_PROJECT_ID="4254406" # AutoGPT Platform project ID
|
||||
|
||||
# Verify credentials work by running orchestrator
|
||||
node orchestrator/orchestrator.js DEV cloud
|
||||
```
|
||||
|
||||
## 📂 File Structure
|
||||
|
||||
```
|
||||
load-tests/
|
||||
├── README.md # This documentation
|
||||
├── generate-tokens.js # Generate pre-auth tokens (MAIN TOKEN SETUP)
|
||||
├── package.json # Node.js dependencies and scripts
|
||||
├── orchestrator/
|
||||
│ └── orchestrator.js # Main test orchestrator (MAIN ENTRY POINT)
|
||||
├── configs/
|
||||
│ ├── environment.js # Environment URLs and configuration
|
||||
│ └── pre-authenticated-tokens.js # Generated tokens (gitignored)
|
||||
├── tests/
|
||||
│ ├── basic/
|
||||
│ │ ├── connectivity-test.js # Basic connectivity validation
|
||||
│ │ └── single-endpoint-test.js # Individual API endpoint testing
|
||||
│ ├── api/
|
||||
│ │ ├── core-api-test.js # Core authenticated API endpoints
|
||||
│ │ └── graph-execution-test.js # Graph workflow pipeline testing
|
||||
│ ├── marketplace/
|
||||
│ │ ├── public-access-test.js # Public marketplace browsing
|
||||
│ │ └── library-access-test.js # Authenticated marketplace/library
|
||||
│ └── comprehensive/
|
||||
│ └── platform-journey-test.js # Complete user journey simulation
|
||||
├── results/ # Local test results (auto-created)
|
||||
├── unified-results-*.json # Orchestrator results (auto-created)
|
||||
└── *.log # Test execution logs (auto-created)
|
||||
```
|
||||
|
||||
## 🎯 Best Practices
|
||||
|
||||
1. **Generate Tokens First**: Always run `node generate-tokens.js --count=160` before testing
|
||||
2. **Local for Development**: Use `DEV local` for debugging and development
|
||||
3. **Cloud for Performance**: Use `DEV cloud` for actual performance testing
|
||||
4. **Monitor Real-Time**: Check k6 cloud dashboards during test execution
|
||||
5. **Regenerate Tokens**: Refresh tokens every 24 hours when they expire
|
||||
6. **Unified Testing**: Orchestrator runs 12 optimized tests automatically
|
||||
|
||||
## 🚀 Advanced Usage
|
||||
|
||||
### Direct k6 Execution
|
||||
|
||||
For granular control over individual test scripts:
|
||||
|
||||
```bash
|
||||
# k6 Cloud execution (recommended for performance testing)
|
||||
# IMPORTANT: Use --env syntax for k6 cloud to ensure proper VU scaling
|
||||
k6 cloud run --env K6_ENVIRONMENT=DEV --env VUS=160 --env DURATION=5m --env RAMP_UP=30s --env RAMP_DOWN=30s tests/marketplace/library-access-test.js
|
||||
|
||||
# Local execution with cloud output (debugging)
|
||||
K6_ENVIRONMENT=DEV VUS=10 DURATION=1m \
|
||||
k6 run tests/api/core-api-test.js --out cloud
|
||||
|
||||
# Local execution with JSON output (offline testing)
|
||||
K6_ENVIRONMENT=DEV VUS=10 DURATION=1m \
|
||||
k6 run tests/api/core-api-test.js --out json=results.json
|
||||
```
|
||||
|
||||
### Custom Token Generation
|
||||
|
||||
```bash
|
||||
# Generate specific number of tokens
|
||||
node generate-tokens.js --count=200
|
||||
|
||||
# Generate tokens with custom timeout
|
||||
node generate-tokens.js --count=100 --timeout=60
|
||||
```
|
||||
|
||||
## 🔗 Related Documentation
|
||||
|
||||
- [k6 Documentation](https://k6.io/docs/)
|
||||
- [AutoGPT Platform API Documentation](https://docs.agpt.co/)
|
||||
- [k6 Cloud Dashboard](https://significantgravitas.grafana.net/a/k6-app/)
|
||||
|
||||
For questions or issues, please refer to the [AutoGPT Platform issues](https://github.com/Significant-Gravitas/AutoGPT/issues).
|
||||
141
autogpt_platform/backend/load-tests/configs/environment.js
Normal file
141
autogpt_platform/backend/load-tests/configs/environment.js
Normal file
@@ -0,0 +1,141 @@
|
||||
// Environment configuration for AutoGPT Platform load tests
|
||||
export const ENV_CONFIG = {
|
||||
DEV: {
|
||||
API_BASE_URL: "https://dev-server.agpt.co",
|
||||
BUILDER_BASE_URL: "https://dev-builder.agpt.co",
|
||||
WS_BASE_URL: "wss://dev-ws-server.agpt.co",
|
||||
SUPABASE_URL: "https://adfjtextkuilwuhzdjpf.supabase.co",
|
||||
SUPABASE_ANON_KEY:
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImFkZmp0ZXh0a3VpbHd1aHpkanBmIiwicm9sZSI6ImFub24iLCJpYXQiOjE3MzAyNTE3MDIsImV4cCI6MjA0NTgyNzcwMn0.IuQNXsHEKJNxtS9nyFeqO0BGMYN8sPiObQhuJLSK9xk",
|
||||
},
|
||||
LOCAL: {
|
||||
API_BASE_URL: "http://localhost:8006",
|
||||
BUILDER_BASE_URL: "http://localhost:3000",
|
||||
WS_BASE_URL: "ws://localhost:8001",
|
||||
SUPABASE_URL: "http://localhost:8000",
|
||||
SUPABASE_ANON_KEY:
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE",
|
||||
},
|
||||
PROD: {
|
||||
API_BASE_URL: "https://api.agpt.co",
|
||||
BUILDER_BASE_URL: "https://builder.agpt.co",
|
||||
WS_BASE_URL: "wss://ws-server.agpt.co",
|
||||
SUPABASE_URL: "https://supabase.agpt.co",
|
||||
SUPABASE_ANON_KEY:
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImJnd3B3ZHN4YmxyeWloaW51dGJ4Iiwicm9sZSI6ImFub24iLCJpYXQiOjE3MzAyODYzMDUsImV4cCI6MjA0NTg2MjMwNX0.ISa2IofTdQIJmmX5JwKGGNajqjsD8bjaGBzK90SubE0",
|
||||
},
|
||||
};
|
||||
|
||||
// Get environment config based on K6_ENVIRONMENT variable (default: DEV)
|
||||
export function getEnvironmentConfig() {
|
||||
const env = __ENV.K6_ENVIRONMENT || "DEV";
|
||||
return ENV_CONFIG[env];
|
||||
}
|
||||
|
||||
// Authentication configuration
|
||||
export const AUTH_CONFIG = {
|
||||
// Test user credentials - REPLACE WITH ACTUAL TEST ACCOUNTS
|
||||
TEST_USERS: [
|
||||
{
|
||||
email: "loadtest1@example.com",
|
||||
password: "LoadTest123!",
|
||||
user_id: "test-user-1",
|
||||
},
|
||||
{
|
||||
email: "loadtest2@example.com",
|
||||
password: "LoadTest123!",
|
||||
user_id: "test-user-2",
|
||||
},
|
||||
{
|
||||
email: "loadtest3@example.com",
|
||||
password: "LoadTest123!",
|
||||
user_id: "test-user-3",
|
||||
},
|
||||
],
|
||||
|
||||
// JWT token for API access (will be set during test execution)
|
||||
JWT_TOKEN: null,
|
||||
};
|
||||
|
||||
// Performance test configurations - Environment variable overrides supported
|
||||
export const PERFORMANCE_CONFIG = {
|
||||
// Default load test parameters (override with env vars: VUS, DURATION, RAMP_UP, RAMP_DOWN)
|
||||
DEFAULT_VUS: parseInt(__ENV.VUS) || 10,
|
||||
DEFAULT_DURATION: __ENV.DURATION || "2m",
|
||||
DEFAULT_RAMP_UP: __ENV.RAMP_UP || "30s",
|
||||
DEFAULT_RAMP_DOWN: __ENV.RAMP_DOWN || "30s",
|
||||
|
||||
// Stress test parameters (override with env vars: STRESS_VUS, STRESS_DURATION, etc.)
|
||||
STRESS_VUS: parseInt(__ENV.STRESS_VUS) || 50,
|
||||
STRESS_DURATION: __ENV.STRESS_DURATION || "5m",
|
||||
STRESS_RAMP_UP: __ENV.STRESS_RAMP_UP || "1m",
|
||||
STRESS_RAMP_DOWN: __ENV.STRESS_RAMP_DOWN || "1m",
|
||||
|
||||
// Spike test parameters (override with env vars: SPIKE_VUS, SPIKE_DURATION, etc.)
|
||||
SPIKE_VUS: parseInt(__ENV.SPIKE_VUS) || 100,
|
||||
SPIKE_DURATION: __ENV.SPIKE_DURATION || "30s",
|
||||
SPIKE_RAMP_UP: __ENV.SPIKE_RAMP_UP || "10s",
|
||||
SPIKE_RAMP_DOWN: __ENV.SPIKE_RAMP_DOWN || "10s",
|
||||
|
||||
// Volume test parameters (override with env vars: VOLUME_VUS, VOLUME_DURATION, etc.)
|
||||
VOLUME_VUS: parseInt(__ENV.VOLUME_VUS) || 20,
|
||||
VOLUME_DURATION: __ENV.VOLUME_DURATION || "10m",
|
||||
VOLUME_RAMP_UP: __ENV.VOLUME_RAMP_UP || "2m",
|
||||
VOLUME_RAMP_DOWN: __ENV.VOLUME_RAMP_DOWN || "2m",
|
||||
|
||||
// SLA thresholds (adjustable via env vars: THRESHOLD_P95, THRESHOLD_P99, etc.)
|
||||
THRESHOLDS: {
|
||||
http_req_duration: [
|
||||
`p(95)<${__ENV.THRESHOLD_P95 || "2000"}`,
|
||||
`p(99)<${__ENV.THRESHOLD_P99 || "5000"}`,
|
||||
],
|
||||
http_req_failed: [`rate<${__ENV.THRESHOLD_ERROR_RATE || "0.05"}`],
|
||||
http_reqs: [`rate>${__ENV.THRESHOLD_RPS || "10"}`],
|
||||
checks: [`rate>${__ENV.THRESHOLD_CHECK_RATE || "0.95"}`],
|
||||
},
|
||||
};
|
||||
|
||||
// Helper function to get load test configuration based on test type
|
||||
export function getLoadTestConfig(testType = "default") {
|
||||
const configs = {
|
||||
default: {
|
||||
vus: PERFORMANCE_CONFIG.DEFAULT_VUS,
|
||||
duration: PERFORMANCE_CONFIG.DEFAULT_DURATION,
|
||||
rampUp: PERFORMANCE_CONFIG.DEFAULT_RAMP_UP,
|
||||
rampDown: PERFORMANCE_CONFIG.DEFAULT_RAMP_DOWN,
|
||||
},
|
||||
stress: {
|
||||
vus: PERFORMANCE_CONFIG.STRESS_VUS,
|
||||
duration: PERFORMANCE_CONFIG.STRESS_DURATION,
|
||||
rampUp: PERFORMANCE_CONFIG.STRESS_RAMP_UP,
|
||||
rampDown: PERFORMANCE_CONFIG.STRESS_RAMP_DOWN,
|
||||
},
|
||||
spike: {
|
||||
vus: PERFORMANCE_CONFIG.SPIKE_VUS,
|
||||
duration: PERFORMANCE_CONFIG.SPIKE_DURATION,
|
||||
rampUp: PERFORMANCE_CONFIG.SPIKE_RAMP_UP,
|
||||
rampDown: PERFORMANCE_CONFIG.SPIKE_RAMP_DOWN,
|
||||
},
|
||||
volume: {
|
||||
vus: PERFORMANCE_CONFIG.VOLUME_VUS,
|
||||
duration: PERFORMANCE_CONFIG.VOLUME_DURATION,
|
||||
rampUp: PERFORMANCE_CONFIG.VOLUME_RAMP_UP,
|
||||
rampDown: PERFORMANCE_CONFIG.VOLUME_RAMP_DOWN,
|
||||
},
|
||||
};
|
||||
|
||||
return configs[testType] || configs.default;
|
||||
}
|
||||
|
||||
// Grafana Cloud K6 configuration
|
||||
export const GRAFANA_CONFIG = {
|
||||
PROJECT_ID: __ENV.K6_CLOUD_PROJECT_ID || "",
|
||||
TOKEN: __ENV.K6_CLOUD_TOKEN || "",
|
||||
// Tags for organizing test results
|
||||
TEST_TAGS: {
|
||||
team: "platform",
|
||||
service: "autogpt-platform",
|
||||
environment: __ENV.K6_ENVIRONMENT || "dev",
|
||||
version: __ENV.GIT_COMMIT || "unknown",
|
||||
},
|
||||
};
|
||||
@@ -0,0 +1,9 @@
|
||||
# k6 Cloud Credentials (EXAMPLE FILE)
|
||||
# Copy this to k6-credentials.env and fill in your actual credentials
|
||||
#
|
||||
# Get these from: https://app.k6.io/
|
||||
# - K6_CLOUD_TOKEN: Your k6 cloud API token
|
||||
# - K6_CLOUD_PROJECT_ID: Your project ID
|
||||
|
||||
K6_CLOUD_TOKEN=your-k6-cloud-token-here
|
||||
K6_CLOUD_PROJECT_ID=your-project-id-here
|
||||
@@ -0,0 +1,51 @@
|
||||
// Pre-authenticated tokens for load testing (EXAMPLE FILE)
|
||||
// Copy this to pre-authenticated-tokens.js and run generate-tokens.js to populate
|
||||
//
|
||||
// ⚠️ SECURITY: The real file contains authentication tokens
|
||||
// ⚠️ DO NOT COMMIT TO GIT - Real file is gitignored
|
||||
|
||||
export const PRE_AUTHENTICATED_TOKENS = [
|
||||
// Will be populated by generate-tokens.js with 350+ real tokens
|
||||
// Example structure:
|
||||
// {
|
||||
// token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||
// user: "loadtest4@example.com",
|
||||
// generated: "2025-01-24T10:08:04.123Z",
|
||||
// round: 1
|
||||
// }
|
||||
];
|
||||
|
||||
export function getPreAuthenticatedToken(vuId = 1) {
|
||||
if (PRE_AUTHENTICATED_TOKENS.length === 0) {
|
||||
throw new Error(
|
||||
"No pre-authenticated tokens available. Run: node generate-tokens.js",
|
||||
);
|
||||
}
|
||||
|
||||
const tokenIndex = (vuId - 1) % PRE_AUTHENTICATED_TOKENS.length;
|
||||
const tokenData = PRE_AUTHENTICATED_TOKENS[tokenIndex];
|
||||
|
||||
return {
|
||||
access_token: tokenData.token,
|
||||
user: { email: tokenData.user },
|
||||
generated: tokenData.generated,
|
||||
};
|
||||
}
|
||||
|
||||
export function getPreAuthenticatedHeaders(vuId = 1) {
|
||||
const authData = getPreAuthenticatedToken(vuId);
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${authData.access_token}`,
|
||||
};
|
||||
}
|
||||
|
||||
export const TOKEN_STATS = {
|
||||
total: PRE_AUTHENTICATED_TOKENS.length,
|
||||
users: [...new Set(PRE_AUTHENTICATED_TOKENS.map((t) => t.user))].length,
|
||||
generated: PRE_AUTHENTICATED_TOKENS[0]?.generated || "unknown",
|
||||
};
|
||||
|
||||
console.log(
|
||||
`🔐 Loaded ${TOKEN_STATS.total} pre-authenticated tokens from ${TOKEN_STATS.users} users`,
|
||||
);
|
||||
236
autogpt_platform/backend/load-tests/generate-tokens.js
Normal file
236
autogpt_platform/backend/load-tests/generate-tokens.js
Normal file
@@ -0,0 +1,236 @@
|
||||
#!/usr/bin/env node
|
||||
|
||||
/**
|
||||
* Generate Pre-Authenticated Tokens for Load Testing
|
||||
* Creates configs/pre-authenticated-tokens.js with 350+ tokens
|
||||
*
|
||||
* This replaces the old token generation scripts with a clean, single script
|
||||
*/
|
||||
|
||||
import https from "https";
|
||||
import fs from "fs";
|
||||
import path from "path";
|
||||
|
||||
// Get Supabase service key from environment (REQUIRED for token generation)
|
||||
const SUPABASE_SERVICE_KEY = process.env.SUPABASE_SERVICE_KEY;
|
||||
|
||||
if (!SUPABASE_SERVICE_KEY) {
|
||||
console.error("❌ SUPABASE_SERVICE_KEY environment variable is required");
|
||||
console.error("Get service key from kubectl or environment:");
|
||||
console.error('export SUPABASE_SERVICE_KEY="your-service-key"');
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
// Generate test users (loadtest4-50 are known to work)
|
||||
const TEST_USERS = [];
|
||||
for (let i = 4; i <= 50; i++) {
|
||||
TEST_USERS.push({
|
||||
email: `loadtest${i}@example.com`,
|
||||
password: "password123",
|
||||
});
|
||||
}
|
||||
|
||||
console.log(
|
||||
`🔐 Generating pre-authenticated tokens from ${TEST_USERS.length} users...`,
|
||||
);
|
||||
|
||||
async function authenticateUser(user, attempt = 1) {
|
||||
return new Promise((resolve) => {
|
||||
const postData = JSON.stringify({
|
||||
email: user.email,
|
||||
password: user.password,
|
||||
expires_in: 86400, // 24 hours in seconds (24 * 60 * 60)
|
||||
});
|
||||
|
||||
const options = {
|
||||
hostname: "adfjtextkuilwuhzdjpf.supabase.co",
|
||||
path: "/auth/v1/token?grant_type=password",
|
||||
method: "POST",
|
||||
headers: {
|
||||
Authorization: `Bearer ${SUPABASE_SERVICE_KEY}`,
|
||||
apikey: SUPABASE_SERVICE_KEY,
|
||||
"Content-Type": "application/json",
|
||||
"Content-Length": postData.length,
|
||||
},
|
||||
};
|
||||
|
||||
const req = https.request(options, (res) => {
|
||||
let data = "";
|
||||
res.on("data", (chunk) => (data += chunk));
|
||||
res.on("end", () => {
|
||||
try {
|
||||
if (res.statusCode === 200) {
|
||||
const authData = JSON.parse(data);
|
||||
resolve(authData.access_token);
|
||||
} else if (res.statusCode === 429) {
|
||||
// Rate limited - wait and retry
|
||||
console.log(
|
||||
`⏳ Rate limited for ${user.email}, waiting 5s (attempt ${attempt}/3)...`,
|
||||
);
|
||||
setTimeout(() => {
|
||||
if (attempt < 3) {
|
||||
authenticateUser(user, attempt + 1).then(resolve);
|
||||
} else {
|
||||
console.log(`❌ Max retries exceeded for ${user.email}`);
|
||||
resolve(null);
|
||||
}
|
||||
}, 5000);
|
||||
} else {
|
||||
console.log(`❌ Auth failed for ${user.email}: ${res.statusCode}`);
|
||||
resolve(null);
|
||||
}
|
||||
} catch (e) {
|
||||
console.log(`❌ Parse error for ${user.email}:`, e.message);
|
||||
resolve(null);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
req.on("error", (err) => {
|
||||
console.log(`❌ Request error for ${user.email}:`, err.message);
|
||||
resolve(null);
|
||||
});
|
||||
|
||||
req.write(postData);
|
||||
req.end();
|
||||
});
|
||||
}
|
||||
|
||||
async function generateTokens() {
|
||||
console.log("🚀 Starting token generation...");
|
||||
console.log("Rate limit aware - this will take ~10-15 minutes");
|
||||
console.log("===========================================\n");
|
||||
|
||||
const tokens = [];
|
||||
const startTime = Date.now();
|
||||
|
||||
// Generate tokens - configurable via --count argument or default to 150
|
||||
const targetTokens =
|
||||
parseInt(
|
||||
process.argv.find((arg) => arg.startsWith("--count="))?.split("=")[1],
|
||||
) ||
|
||||
parseInt(process.env.TOKEN_COUNT) ||
|
||||
150;
|
||||
const tokensPerUser = Math.ceil(targetTokens / TEST_USERS.length);
|
||||
console.log(
|
||||
`📊 Generating ${tokensPerUser} tokens per user (${TEST_USERS.length} users) - Target: ${targetTokens}\n`,
|
||||
);
|
||||
|
||||
for (let round = 1; round <= tokensPerUser; round++) {
|
||||
console.log(`🔄 Round ${round}/${tokensPerUser}:`);
|
||||
|
||||
for (
|
||||
let i = 0;
|
||||
i < TEST_USERS.length && tokens.length < targetTokens;
|
||||
i++
|
||||
) {
|
||||
const user = TEST_USERS[i];
|
||||
|
||||
process.stdout.write(` ${user.email.padEnd(25)} ... `);
|
||||
|
||||
const token = await authenticateUser(user);
|
||||
|
||||
if (token) {
|
||||
tokens.push({
|
||||
token,
|
||||
user: user.email,
|
||||
generated: new Date().toISOString(),
|
||||
round: round,
|
||||
});
|
||||
console.log(`✅ (${tokens.length}/${targetTokens})`);
|
||||
} else {
|
||||
console.log(`❌`);
|
||||
}
|
||||
|
||||
// Respect rate limits - wait 500ms between requests
|
||||
if (tokens.length < targetTokens) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 500));
|
||||
}
|
||||
}
|
||||
|
||||
if (tokens.length >= targetTokens) break;
|
||||
|
||||
// Wait longer between rounds
|
||||
if (round < tokensPerUser) {
|
||||
console.log(` ⏸️ Waiting 3s before next round...\n`);
|
||||
await new Promise((resolve) => setTimeout(resolve, 3000));
|
||||
}
|
||||
}
|
||||
|
||||
const duration = Math.round((Date.now() - startTime) / 1000);
|
||||
console.log(`\n✅ Generated ${tokens.length} tokens in ${duration}s`);
|
||||
|
||||
// Create configs directory if it doesn't exist
|
||||
const configsDir = path.join(process.cwd(), "configs");
|
||||
if (!fs.existsSync(configsDir)) {
|
||||
fs.mkdirSync(configsDir, { recursive: true });
|
||||
}
|
||||
|
||||
// Write tokens to secure file
|
||||
const jsContent = `// Pre-authenticated tokens for load testing
|
||||
// Generated: ${new Date().toISOString()}
|
||||
// Total tokens: ${tokens.length}
|
||||
// Generation time: ${duration} seconds
|
||||
//
|
||||
// ⚠️ SECURITY: This file contains real authentication tokens
|
||||
// ⚠️ DO NOT COMMIT TO GIT - File is gitignored
|
||||
|
||||
export const PRE_AUTHENTICATED_TOKENS = ${JSON.stringify(tokens, null, 2)};
|
||||
|
||||
export function getPreAuthenticatedToken(vuId = 1) {
|
||||
if (PRE_AUTHENTICATED_TOKENS.length === 0) {
|
||||
throw new Error('No pre-authenticated tokens available');
|
||||
}
|
||||
|
||||
const tokenIndex = (vuId - 1) % PRE_AUTHENTICATED_TOKENS.length;
|
||||
const tokenData = PRE_AUTHENTICATED_TOKENS[tokenIndex];
|
||||
|
||||
return {
|
||||
access_token: tokenData.token,
|
||||
user: { email: tokenData.user },
|
||||
generated: tokenData.generated
|
||||
};
|
||||
}
|
||||
|
||||
// Generate single session ID for this test run
|
||||
const LOAD_TEST_SESSION_ID = '${new Date().toISOString().slice(0, 16).replace(/:/g, "-")}-' + Math.random().toString(36).substr(2, 8);
|
||||
|
||||
export function getPreAuthenticatedHeaders(vuId = 1) {
|
||||
const authData = getPreAuthenticatedToken(vuId);
|
||||
|
||||
return {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': \`Bearer \${authData.access_token}\`,
|
||||
'X-Load-Test-Session': LOAD_TEST_SESSION_ID,
|
||||
'X-Load-Test-VU': vuId.toString(),
|
||||
'X-Load-Test-User': authData.user.email,
|
||||
};
|
||||
}
|
||||
|
||||
export const TOKEN_STATS = {
|
||||
total: PRE_AUTHENTICATED_TOKENS.length,
|
||||
users: [...new Set(PRE_AUTHENTICATED_TOKENS.map(t => t.user))].length,
|
||||
generated: PRE_AUTHENTICATED_TOKENS[0]?.generated || 'unknown'
|
||||
};
|
||||
|
||||
console.log(\`🔐 Loaded \${TOKEN_STATS.total} pre-authenticated tokens from \${TOKEN_STATS.users} users\`);
|
||||
`;
|
||||
|
||||
const tokenFile = path.join(configsDir, "pre-authenticated-tokens.js");
|
||||
fs.writeFileSync(tokenFile, jsContent);
|
||||
|
||||
console.log(`💾 Saved to configs/pre-authenticated-tokens.js`);
|
||||
console.log(`🚀 Ready for ${tokens.length} concurrent VU load testing!`);
|
||||
console.log(
|
||||
`\n🔒 Security Note: Token file is gitignored and will not be committed`,
|
||||
);
|
||||
|
||||
return tokens.length;
|
||||
}
|
||||
|
||||
// Run if called directly
|
||||
if (process.argv[1] === new URL(import.meta.url).pathname) {
|
||||
generateTokens().catch(console.error);
|
||||
}
|
||||
|
||||
export { generateTokens };
|
||||
362
autogpt_platform/backend/load-tests/orchestrator/orchestrator.js
Normal file
362
autogpt_platform/backend/load-tests/orchestrator/orchestrator.js
Normal file
@@ -0,0 +1,362 @@
|
||||
#!/usr/bin/env node
|
||||
|
||||
/**
|
||||
* AutoGPT Platform Load Test Orchestrator
|
||||
*
|
||||
* Optimized test suite with only the highest VU count for each unique test type.
|
||||
* Eliminates duplicate tests and focuses on maximum load testing.
|
||||
*/
|
||||
|
||||
import { spawn } from 'child_process';
|
||||
import fs from 'fs';
|
||||
|
||||
console.log("🎯 AUTOGPT PLATFORM LOAD TEST ORCHESTRATOR\n");
|
||||
console.log("===========================================\n");
|
||||
|
||||
// Parse command line arguments
|
||||
const args = process.argv.slice(2);
|
||||
const environment = args[0] || "DEV"; // LOCAL, DEV, PROD
|
||||
const executionMode = args[1] || "cloud"; // local, cloud
|
||||
|
||||
console.log(`🌍 Target Environment: ${environment}`);
|
||||
console.log(`🚀 Execution Mode: ${executionMode}`);
|
||||
|
||||
// Unified test scenarios - only highest VUs for each unique test
|
||||
const unifiedTestScenarios = [
|
||||
// 1. Marketplace Public Access (highest VUs: 314)
|
||||
{
|
||||
name: "Marketplace_Public_Access_Max_Load",
|
||||
file: "tests/marketplace/public-access-test.js",
|
||||
vus: 314,
|
||||
duration: "3m",
|
||||
rampUp: "30s",
|
||||
rampDown: "30s",
|
||||
description: "Public marketplace browsing at maximum load"
|
||||
},
|
||||
|
||||
// 2. Marketplace Authenticated Access (highest VUs: 157)
|
||||
{
|
||||
name: "Marketplace_Authenticated_Access_Max_Load",
|
||||
file: "tests/marketplace/library-access-test.js",
|
||||
vus: 157,
|
||||
duration: "3m",
|
||||
rampUp: "30s",
|
||||
rampDown: "30s",
|
||||
description: "Authenticated marketplace/library operations at maximum load"
|
||||
},
|
||||
|
||||
// 3. Core API Load Test (highest VUs: 100)
|
||||
{
|
||||
name: "Core_API_Max_Load",
|
||||
file: "tests/api/core-api-test.js",
|
||||
vus: 100,
|
||||
duration: "5m",
|
||||
rampUp: "1m",
|
||||
rampDown: "1m",
|
||||
description: "Core authenticated API endpoints at maximum load"
|
||||
},
|
||||
|
||||
// 4. Graph Execution Load Test (highest VUs: 100)
|
||||
{
|
||||
name: "Graph_Execution_Max_Load",
|
||||
file: "tests/api/graph-execution-test.js",
|
||||
vus: 100,
|
||||
duration: "5m",
|
||||
rampUp: "1m",
|
||||
rampDown: "1m",
|
||||
description: "Graph workflow execution pipeline at maximum load"
|
||||
},
|
||||
|
||||
// 5. Credits API Single Endpoint (upgraded to 100 VUs)
|
||||
{
|
||||
name: "Credits_API_Max_Load",
|
||||
file: "tests/basic/single-endpoint-test.js",
|
||||
vus: 100,
|
||||
duration: "3m",
|
||||
rampUp: "30s",
|
||||
rampDown: "30s",
|
||||
env: { ENDPOINT: "credits", CONCURRENT_REQUESTS: "1" },
|
||||
description: "Credits API endpoint at maximum load"
|
||||
},
|
||||
|
||||
// 6. Graphs API Single Endpoint (upgraded to 100 VUs)
|
||||
{
|
||||
name: "Graphs_API_Max_Load",
|
||||
file: "tests/basic/single-endpoint-test.js",
|
||||
vus: 100,
|
||||
duration: "3m",
|
||||
rampUp: "30s",
|
||||
rampDown: "30s",
|
||||
env: { ENDPOINT: "graphs", CONCURRENT_REQUESTS: "1" },
|
||||
description: "Graphs API endpoint at maximum load"
|
||||
},
|
||||
|
||||
// 7. Blocks API Single Endpoint (upgraded to 100 VUs)
|
||||
{
|
||||
name: "Blocks_API_Max_Load",
|
||||
file: "tests/basic/single-endpoint-test.js",
|
||||
vus: 100,
|
||||
duration: "3m",
|
||||
rampUp: "30s",
|
||||
rampDown: "30s",
|
||||
env: { ENDPOINT: "blocks", CONCURRENT_REQUESTS: "1" },
|
||||
description: "Blocks API endpoint at maximum load"
|
||||
},
|
||||
|
||||
// 8. Executions API Single Endpoint (upgraded to 100 VUs)
|
||||
{
|
||||
name: "Executions_API_Max_Load",
|
||||
file: "tests/basic/single-endpoint-test.js",
|
||||
vus: 100,
|
||||
duration: "3m",
|
||||
rampUp: "30s",
|
||||
rampDown: "30s",
|
||||
env: { ENDPOINT: "executions", CONCURRENT_REQUESTS: "1" },
|
||||
description: "Executions API endpoint at maximum load"
|
||||
},
|
||||
|
||||
// 9. Comprehensive Platform Journey (highest VUs: 100)
|
||||
{
|
||||
name: "Comprehensive_Platform_Max_Load",
|
||||
file: "tests/comprehensive/platform-journey-test.js",
|
||||
vus: 100,
|
||||
duration: "3m",
|
||||
rampUp: "30s",
|
||||
rampDown: "30s",
|
||||
description: "End-to-end user journey simulation at maximum load"
|
||||
},
|
||||
|
||||
// 10. Marketplace Stress Test (highest VUs: 500)
|
||||
{
|
||||
name: "Marketplace_Stress_Test",
|
||||
file: "tests/marketplace/public-access-test.js",
|
||||
vus: 500,
|
||||
duration: "2m",
|
||||
rampUp: "1m",
|
||||
rampDown: "1m",
|
||||
description: "Ultimate marketplace stress test"
|
||||
},
|
||||
|
||||
// 11. Core API Stress Test (highest VUs: 500)
|
||||
{
|
||||
name: "Core_API_Stress_Test",
|
||||
file: "tests/api/core-api-test.js",
|
||||
vus: 500,
|
||||
duration: "2m",
|
||||
rampUp: "1m",
|
||||
rampDown: "1m",
|
||||
description: "Ultimate core API stress test"
|
||||
},
|
||||
|
||||
// 12. Long Duration Core API Test (highest VUs: 100, longest duration)
|
||||
{
|
||||
name: "Long_Duration_Core_API_Test",
|
||||
file: "tests/api/core-api-test.js",
|
||||
vus: 100,
|
||||
duration: "10m",
|
||||
rampUp: "1m",
|
||||
rampDown: "1m",
|
||||
description: "Extended duration core API endurance test"
|
||||
}
|
||||
];
|
||||
|
||||
// Configuration
|
||||
const K6_CLOUD_TOKEN = process.env.K6_CLOUD_TOKEN || '9347b8bd716cadc243e92f7d2f89107febfb81b49f2340d17da515d7b0513b51';
|
||||
const K6_CLOUD_PROJECT_ID = process.env.K6_CLOUD_PROJECT_ID || '4254406';
|
||||
const PAUSE_BETWEEN_TESTS = 30; // seconds
|
||||
|
||||
/**
|
||||
* Sleep for specified milliseconds
|
||||
*/
|
||||
function sleep(ms) {
|
||||
return new Promise(resolve => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
/**
|
||||
* Run a single k6 test
|
||||
*/
|
||||
async function runTest(test, index) {
|
||||
return new Promise((resolve, reject) => {
|
||||
console.log(`\n🚀 Test ${index + 1}/${unifiedTestScenarios.length}: ${test.name}`);
|
||||
console.log(`📊 Config: ${test.vus} VUs × ${test.duration} (${executionMode} mode)`);
|
||||
console.log(`📁 Script: ${test.file}`);
|
||||
console.log(`📋 Description: ${test.description}`);
|
||||
console.log(`⏱️ Test started: ${new Date().toISOString()}`);
|
||||
|
||||
const env = {
|
||||
K6_CLOUD_TOKEN,
|
||||
K6_CLOUD_PROJECT_ID,
|
||||
K6_ENVIRONMENT: environment,
|
||||
VUS: test.vus.toString(),
|
||||
DURATION: test.duration,
|
||||
RAMP_UP: test.rampUp,
|
||||
RAMP_DOWN: test.rampDown,
|
||||
...test.env
|
||||
};
|
||||
|
||||
let args;
|
||||
if (executionMode === 'cloud') {
|
||||
args = [
|
||||
'cloud', 'run',
|
||||
...Object.entries(env).map(([key, value]) => ['--env', `${key}=${value}`]).flat(),
|
||||
test.file
|
||||
];
|
||||
} else {
|
||||
args = [
|
||||
'run',
|
||||
...Object.entries(env).map(([key, value]) => ['--env', `${key}=${value}`]).flat(),
|
||||
test.file
|
||||
];
|
||||
}
|
||||
|
||||
const k6Process = spawn('k6', args, {
|
||||
stdio: ['ignore', 'pipe', 'pipe'],
|
||||
env: { ...process.env, ...env }
|
||||
});
|
||||
|
||||
let output = '';
|
||||
let testId = null;
|
||||
|
||||
k6Process.stdout.on('data', (data) => {
|
||||
const str = data.toString();
|
||||
output += str;
|
||||
|
||||
// Extract test ID from k6 cloud output
|
||||
const testIdMatch = str.match(/Test created: .*\/(\d+)/);
|
||||
if (testIdMatch) {
|
||||
testId = testIdMatch[1];
|
||||
console.log(`🔗 Test URL: https://significantgravitas.grafana.net/a/k6-app/runs/${testId}`);
|
||||
}
|
||||
|
||||
// Show progress updates
|
||||
const progressMatch = str.match(/(\d+)%/);
|
||||
if (progressMatch) {
|
||||
process.stdout.write(`\r⏳ Progress: ${progressMatch[1]}%`);
|
||||
}
|
||||
});
|
||||
|
||||
k6Process.stderr.on('data', (data) => {
|
||||
output += data.toString();
|
||||
});
|
||||
|
||||
k6Process.on('close', (code) => {
|
||||
process.stdout.write('\n'); // Clear progress line
|
||||
|
||||
if (code === 0) {
|
||||
console.log(`✅ ${test.name} SUCCESS`);
|
||||
resolve({
|
||||
success: true,
|
||||
testId,
|
||||
url: testId ? `https://significantgravitas.grafana.net/a/k6-app/runs/${testId}` : 'unknown',
|
||||
vus: test.vus,
|
||||
duration: test.duration
|
||||
});
|
||||
} else {
|
||||
console.log(`❌ ${test.name} FAILED (exit code ${code})`);
|
||||
resolve({
|
||||
success: false,
|
||||
testId,
|
||||
url: testId ? `https://significantgravitas.grafana.net/a/k6-app/runs/${testId}` : 'unknown',
|
||||
exitCode: code,
|
||||
vus: test.vus,
|
||||
duration: test.duration
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
k6Process.on('error', (error) => {
|
||||
console.log(`❌ ${test.name} ERROR: ${error.message}`);
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Main execution
|
||||
*/
|
||||
async function main() {
|
||||
console.log(`\n📋 UNIFIED TEST PLAN`);
|
||||
console.log(`📊 Total tests: ${unifiedTestScenarios.length} (reduced from 25 original tests)`);
|
||||
console.log(`⏱️ Estimated duration: ~60 minutes\n`);
|
||||
|
||||
console.log(`📋 Test Summary:`);
|
||||
unifiedTestScenarios.forEach((test, i) => {
|
||||
console.log(` ${i + 1}. ${test.name} (${test.vus} VUs × ${test.duration})`);
|
||||
});
|
||||
console.log('');
|
||||
|
||||
const results = [];
|
||||
|
||||
for (let i = 0; i < unifiedTestScenarios.length; i++) {
|
||||
const test = unifiedTestScenarios[i];
|
||||
|
||||
try {
|
||||
const result = await runTest(test, i);
|
||||
results.push({ ...test, ...result });
|
||||
|
||||
// Pause between tests (except after the last one)
|
||||
if (i < unifiedTestScenarios.length - 1) {
|
||||
console.log(`\n⏸️ Pausing ${PAUSE_BETWEEN_TESTS}s before next test...`);
|
||||
await sleep(PAUSE_BETWEEN_TESTS * 1000);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`💥 Fatal error running ${test.name}:`, error.message);
|
||||
results.push({ ...test, success: false, error: error.message });
|
||||
}
|
||||
}
|
||||
|
||||
// Summary
|
||||
console.log('\n' + '='.repeat(60));
|
||||
console.log('🏁 UNIFIED LOAD TEST RESULTS SUMMARY');
|
||||
console.log('='.repeat(60));
|
||||
|
||||
const successful = results.filter(r => r.success);
|
||||
const failed = results.filter(r => !r.success);
|
||||
|
||||
console.log(`✅ Successful tests: ${successful.length}/${results.length} (${Math.round(successful.length / results.length * 100)}%)`);
|
||||
console.log(`❌ Failed tests: ${failed.length}/${results.length}`);
|
||||
|
||||
if (successful.length > 0) {
|
||||
console.log('\n✅ SUCCESSFUL TESTS:');
|
||||
successful.forEach(test => {
|
||||
console.log(` • ${test.name} (${test.vus} VUs) - ${test.url}`);
|
||||
});
|
||||
}
|
||||
|
||||
if (failed.length > 0) {
|
||||
console.log('\n❌ FAILED TESTS:');
|
||||
failed.forEach(test => {
|
||||
console.log(` • ${test.name} (${test.vus} VUs) - ${test.url || 'no URL'} (exit: ${test.exitCode || 'unknown'})`);
|
||||
});
|
||||
}
|
||||
|
||||
// Calculate total VU-minutes tested
|
||||
const totalVuMinutes = results.reduce((sum, test) => {
|
||||
const minutes = parseFloat(test.duration.replace(/[ms]/g, ''));
|
||||
const multiplier = test.duration.includes('m') ? 1 : (1/60); // convert seconds to minutes
|
||||
return sum + (test.vus * minutes * multiplier);
|
||||
}, 0);
|
||||
|
||||
console.log(`\n📊 LOAD TESTING SUMMARY:`);
|
||||
console.log(` • Total VU-minutes tested: ${Math.round(totalVuMinutes)}`);
|
||||
console.log(` • Peak concurrent VUs: ${Math.max(...results.map(r => r.vus))}`);
|
||||
console.log(` • Average test duration: ${(results.reduce((sum, r) => sum + parseFloat(r.duration.replace(/[ms]/g, '')), 0) / results.length).toFixed(1)}${results[0].duration.includes('m') ? 'm' : 's'}`);
|
||||
|
||||
// Write results to file
|
||||
const timestamp = Math.floor(Date.now() / 1000);
|
||||
const resultsFile = `unified-results-${timestamp}.json`;
|
||||
fs.writeFileSync(resultsFile, JSON.stringify(results, null, 2));
|
||||
console.log(`\n📄 Detailed results saved to: ${resultsFile}`);
|
||||
|
||||
console.log(`\n🎉 UNIFIED LOAD TEST ORCHESTRATOR COMPLETE\n`);
|
||||
|
||||
process.exit(failed.length === 0 ? 0 : 1);
|
||||
}
|
||||
|
||||
// Run if called directly
|
||||
if (process.argv[1] === new URL(import.meta.url).pathname) {
|
||||
main().catch(error => {
|
||||
console.error('💥 Fatal error:', error);
|
||||
process.exit(1);
|
||||
});
|
||||
}
|
||||
197
autogpt_platform/backend/load-tests/tests/api/core-api-test.js
Normal file
197
autogpt_platform/backend/load-tests/tests/api/core-api-test.js
Normal file
@@ -0,0 +1,197 @@
|
||||
// Simple API diagnostic test
|
||||
import http from "k6/http";
|
||||
import { check } from "k6";
|
||||
import { getEnvironmentConfig } from "../../configs/environment.js";
|
||||
import { getPreAuthenticatedHeaders } from "../../configs/pre-authenticated-tokens.js";
|
||||
|
||||
const config = getEnvironmentConfig();
|
||||
|
||||
export const options = {
|
||||
stages: [
|
||||
{ duration: __ENV.RAMP_UP || "1m", target: parseInt(__ENV.VUS) || 1 },
|
||||
{ duration: __ENV.DURATION || "5m", target: parseInt(__ENV.VUS) || 1 },
|
||||
{ duration: __ENV.RAMP_DOWN || "1m", target: 0 },
|
||||
],
|
||||
// Thresholds disabled to prevent test abortion - collect all performance data
|
||||
// thresholds: {
|
||||
// checks: ['rate>0.70'],
|
||||
// http_req_duration: ['p(95)<30000'],
|
||||
// http_req_failed: ['rate<0.3'],
|
||||
// },
|
||||
cloud: {
|
||||
projectID: __ENV.K6_CLOUD_PROJECT_ID,
|
||||
name: "AutoGPT Platform - Core API Validation Test",
|
||||
},
|
||||
// Timeout configurations to prevent early termination
|
||||
setupTimeout: "60s",
|
||||
teardownTimeout: "60s",
|
||||
noConnectionReuse: false,
|
||||
userAgent: "k6-load-test/1.0",
|
||||
};
|
||||
|
||||
export default function () {
|
||||
// Get load multiplier - how many concurrent requests each VU should make
|
||||
const requestsPerVU = parseInt(__ENV.REQUESTS_PER_VU) || 1;
|
||||
|
||||
try {
|
||||
// Step 1: Get pre-authenticated headers (no auth API calls during test)
|
||||
const headers = getPreAuthenticatedHeaders(__VU);
|
||||
|
||||
// Handle missing token gracefully
|
||||
if (!headers || !headers.Authorization) {
|
||||
console.log(
|
||||
`⚠️ VU ${__VU} has no valid pre-authenticated token - skipping core API test`,
|
||||
);
|
||||
check(null, {
|
||||
"Core API: Failed gracefully without crashing VU": () => true,
|
||||
});
|
||||
return; // Exit iteration gracefully without crashing
|
||||
}
|
||||
|
||||
console.log(
|
||||
`🚀 VU ${__VU} making ${requestsPerVU} concurrent API requests...`,
|
||||
);
|
||||
|
||||
// Create array of API requests to run concurrently
|
||||
const requests = [];
|
||||
|
||||
for (let i = 0; i < requestsPerVU; i++) {
|
||||
// Add core API requests that represent realistic user workflows
|
||||
requests.push({
|
||||
method: "GET",
|
||||
url: `${config.API_BASE_URL}/api/credits`,
|
||||
params: { headers },
|
||||
});
|
||||
|
||||
requests.push({
|
||||
method: "GET",
|
||||
url: `${config.API_BASE_URL}/api/graphs`,
|
||||
params: { headers },
|
||||
});
|
||||
|
||||
requests.push({
|
||||
method: "GET",
|
||||
url: `${config.API_BASE_URL}/api/blocks`,
|
||||
params: { headers },
|
||||
});
|
||||
}
|
||||
|
||||
// Execute all requests concurrently
|
||||
const responses = http.batch(requests);
|
||||
|
||||
// Validate results
|
||||
let creditsSuccesses = 0;
|
||||
let graphsSuccesses = 0;
|
||||
let blocksSuccesses = 0;
|
||||
|
||||
for (let i = 0; i < responses.length; i++) {
|
||||
const response = responses[i];
|
||||
const apiType = i % 3; // 0=credits, 1=graphs, 2=blocks
|
||||
|
||||
if (apiType === 0) {
|
||||
// Credits API request
|
||||
check(response, {
|
||||
"Credits API: HTTP Status is 200": (r) => r.status === 200,
|
||||
"Credits API: Not Auth Error (401/403)": (r) =>
|
||||
r.status !== 401 && r.status !== 403,
|
||||
"Credits API: Response has valid JSON": (r) => {
|
||||
try {
|
||||
JSON.parse(r.body);
|
||||
return true;
|
||||
} catch (e) {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Credits API: Response has credits field": (r) => {
|
||||
try {
|
||||
const data = JSON.parse(r.body);
|
||||
return data && typeof data.credits === "number";
|
||||
} catch (e) {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Credits API: Overall Success": (r) => {
|
||||
try {
|
||||
if (r.status !== 200) return false;
|
||||
const data = JSON.parse(r.body);
|
||||
return data && typeof data.credits === "number";
|
||||
} catch (e) {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
});
|
||||
} else if (apiType === 1) {
|
||||
// Graphs API request
|
||||
check(response, {
|
||||
"Graphs API: HTTP Status is 200": (r) => r.status === 200,
|
||||
"Graphs API: Not Auth Error (401/403)": (r) =>
|
||||
r.status !== 401 && r.status !== 403,
|
||||
"Graphs API: Response has valid JSON": (r) => {
|
||||
try {
|
||||
JSON.parse(r.body);
|
||||
return true;
|
||||
} catch (e) {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Graphs API: Response is array": (r) => {
|
||||
try {
|
||||
const data = JSON.parse(r.body);
|
||||
return Array.isArray(data);
|
||||
} catch (e) {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Graphs API: Overall Success": (r) => {
|
||||
try {
|
||||
if (r.status !== 200) return false;
|
||||
const data = JSON.parse(r.body);
|
||||
return Array.isArray(data);
|
||||
} catch (e) {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
});
|
||||
} else {
|
||||
// Blocks API request
|
||||
check(response, {
|
||||
"Blocks API: HTTP Status is 200": (r) => r.status === 200,
|
||||
"Blocks API: Not Auth Error (401/403)": (r) =>
|
||||
r.status !== 401 && r.status !== 403,
|
||||
"Blocks API: Response has valid JSON": (r) => {
|
||||
try {
|
||||
JSON.parse(r.body);
|
||||
return true;
|
||||
} catch (e) {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Blocks API: Response has blocks data": (r) => {
|
||||
try {
|
||||
const data = JSON.parse(r.body);
|
||||
return data && (Array.isArray(data) || typeof data === "object");
|
||||
} catch (e) {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Blocks API: Overall Success": (r) => {
|
||||
try {
|
||||
if (r.status !== 200) return false;
|
||||
const data = JSON.parse(r.body);
|
||||
return data && (Array.isArray(data) || typeof data === "object");
|
||||
} catch (e) {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
console.log(
|
||||
`✅ VU ${__VU} completed ${responses.length} API requests with detailed auth/validation tracking`,
|
||||
);
|
||||
} catch (error) {
|
||||
console.error(`💥 Test failed: ${error.message}`);
|
||||
console.error(`💥 Stack: ${error.stack}`);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,249 @@
|
||||
// Dedicated graph execution load testing
|
||||
import http from "k6/http";
|
||||
import { check, sleep, group } from "k6";
|
||||
import { Rate, Trend, Counter } from "k6/metrics";
|
||||
import { getEnvironmentConfig } from "../../configs/environment.js";
|
||||
import { getPreAuthenticatedHeaders } from "../../configs/pre-authenticated-tokens.js";
|
||||
// Test data generation functions
|
||||
function generateTestGraph(name = null) {
|
||||
const graphName =
|
||||
name || `Load Test Graph ${Math.random().toString(36).substr(2, 9)}`;
|
||||
return {
|
||||
name: graphName,
|
||||
description: "Generated graph for load testing purposes",
|
||||
graph: {
|
||||
name: graphName,
|
||||
description: "Load testing graph",
|
||||
nodes: [
|
||||
{
|
||||
id: "input_node",
|
||||
name: "Agent Input",
|
||||
block_id: "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
|
||||
input_default: {
|
||||
name: "Load Test Input",
|
||||
description: "Test input for load testing",
|
||||
placeholder_values: {},
|
||||
},
|
||||
input_nodes: [],
|
||||
output_nodes: ["output_node"],
|
||||
metadata: { position: { x: 100, y: 100 } },
|
||||
},
|
||||
{
|
||||
id: "output_node",
|
||||
name: "Agent Output",
|
||||
block_id: "363ae599-353e-4804-937e-b2ee3cef3da4",
|
||||
input_default: {
|
||||
name: "Load Test Output",
|
||||
description: "Test output for load testing",
|
||||
value: "Test output value",
|
||||
},
|
||||
input_nodes: ["input_node"],
|
||||
output_nodes: [],
|
||||
metadata: { position: { x: 300, y: 100 } },
|
||||
},
|
||||
],
|
||||
links: [
|
||||
{
|
||||
source_id: "input_node",
|
||||
sink_id: "output_node",
|
||||
source_name: "result",
|
||||
sink_name: "value",
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function generateExecutionInputs() {
|
||||
return {
|
||||
"Load Test Input": {
|
||||
name: "Load Test Input",
|
||||
description: "Test input for load testing",
|
||||
placeholder_values: {
|
||||
test_data: `Test execution at ${new Date().toISOString()}`,
|
||||
test_parameter: Math.random().toString(36).substr(2, 9),
|
||||
numeric_value: Math.floor(Math.random() * 1000),
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const config = getEnvironmentConfig();
|
||||
|
||||
// Custom metrics for graph execution testing
|
||||
const graphCreations = new Counter("graph_creations_total");
|
||||
const graphExecutions = new Counter("graph_executions_total");
|
||||
const graphExecutionTime = new Trend("graph_execution_duration");
|
||||
const graphCreationTime = new Trend("graph_creation_duration");
|
||||
const executionErrors = new Rate("execution_errors");
|
||||
|
||||
// Configurable options for easy load adjustment
|
||||
export const options = {
|
||||
stages: [
|
||||
{ duration: __ENV.RAMP_UP || "1m", target: parseInt(__ENV.VUS) || 5 },
|
||||
{ duration: __ENV.DURATION || "5m", target: parseInt(__ENV.VUS) || 5 },
|
||||
{ duration: __ENV.RAMP_DOWN || "1m", target: 0 },
|
||||
],
|
||||
// Thresholds disabled to prevent test abortion - collect all performance data
|
||||
// thresholds: {
|
||||
// checks: ['rate>0.60'],
|
||||
// http_req_duration: ['p(95)<45000', 'p(99)<60000'],
|
||||
// http_req_failed: ['rate<0.4'],
|
||||
// graph_execution_duration: ['p(95)<45000'],
|
||||
// graph_creation_duration: ['p(95)<30000'],
|
||||
// },
|
||||
cloud: {
|
||||
projectID: __ENV.K6_CLOUD_PROJECT_ID,
|
||||
name: "AutoGPT Platform - Graph Creation & Execution Test",
|
||||
},
|
||||
// Timeout configurations to prevent early termination
|
||||
setupTimeout: "60s",
|
||||
teardownTimeout: "60s",
|
||||
noConnectionReuse: false,
|
||||
userAgent: "k6-load-test/1.0",
|
||||
};
|
||||
|
||||
export function setup() {
|
||||
console.log("🎯 Setting up graph execution load test...");
|
||||
console.log(
|
||||
`Configuration: VUs=${parseInt(__ENV.VUS) || 5}, Duration=${__ENV.DURATION || "2m"}`,
|
||||
);
|
||||
return {
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
export default function (data) {
|
||||
// Get load multiplier - how many concurrent operations each VU should perform
|
||||
const requestsPerVU = parseInt(__ENV.REQUESTS_PER_VU) || 1;
|
||||
|
||||
// Get pre-authenticated headers (no auth API calls during test)
|
||||
const headers = getPreAuthenticatedHeaders(__VU);
|
||||
|
||||
// Handle missing token gracefully
|
||||
if (!headers || !headers.Authorization) {
|
||||
console.log(
|
||||
`⚠️ VU ${__VU} has no valid pre-authenticated token - skipping graph execution`,
|
||||
);
|
||||
check(null, {
|
||||
"Graph Execution: Failed gracefully without crashing VU": () => true,
|
||||
});
|
||||
return; // Exit iteration gracefully without crashing
|
||||
}
|
||||
|
||||
console.log(
|
||||
`🚀 VU ${__VU} performing ${requestsPerVU} concurrent graph operations...`,
|
||||
);
|
||||
|
||||
// Create requests for concurrent execution
|
||||
const graphRequests = [];
|
||||
|
||||
for (let i = 0; i < requestsPerVU; i++) {
|
||||
// Generate graph data
|
||||
const graphData = generateTestGraph();
|
||||
|
||||
// Add graph creation request
|
||||
graphRequests.push({
|
||||
method: "POST",
|
||||
url: `${config.API_BASE_URL}/api/graphs`,
|
||||
body: JSON.stringify(graphData),
|
||||
params: { headers },
|
||||
});
|
||||
}
|
||||
|
||||
// Execute all graph creations concurrently
|
||||
console.log(`📊 Creating ${requestsPerVU} graphs concurrently...`);
|
||||
const responses = http.batch(graphRequests);
|
||||
|
||||
// Process results
|
||||
let successCount = 0;
|
||||
const createdGraphs = [];
|
||||
|
||||
for (let i = 0; i < responses.length; i++) {
|
||||
const response = responses[i];
|
||||
|
||||
const success = check(response, {
|
||||
[`Graph ${i + 1} created successfully`]: (r) => r.status === 200,
|
||||
});
|
||||
|
||||
if (success && response.status === 200) {
|
||||
successCount++;
|
||||
try {
|
||||
const graph = JSON.parse(response.body);
|
||||
createdGraphs.push(graph);
|
||||
graphCreations.add(1);
|
||||
} catch (e) {
|
||||
console.error(`Error parsing graph ${i + 1} response:`, e);
|
||||
}
|
||||
} else {
|
||||
console.log(`❌ Graph ${i + 1} creation failed: ${response.status}`);
|
||||
}
|
||||
}
|
||||
|
||||
console.log(
|
||||
`✅ VU ${__VU} created ${successCount}/${requestsPerVU} graphs concurrently`,
|
||||
);
|
||||
|
||||
// Execute a subset of created graphs (to avoid overloading execution)
|
||||
const graphsToExecute = createdGraphs.slice(
|
||||
0,
|
||||
Math.min(5, createdGraphs.length),
|
||||
);
|
||||
|
||||
if (graphsToExecute.length > 0) {
|
||||
console.log(`⚡ Executing ${graphsToExecute.length} graphs...`);
|
||||
|
||||
const executionRequests = [];
|
||||
|
||||
for (const graph of graphsToExecute) {
|
||||
const executionInputs = generateExecutionInputs();
|
||||
|
||||
executionRequests.push({
|
||||
method: "POST",
|
||||
url: `${config.API_BASE_URL}/api/graphs/${graph.id}/execute/${graph.version}`,
|
||||
body: JSON.stringify({
|
||||
inputs: executionInputs,
|
||||
credentials_inputs: {},
|
||||
}),
|
||||
params: { headers },
|
||||
});
|
||||
}
|
||||
|
||||
// Execute graphs concurrently
|
||||
const executionResponses = http.batch(executionRequests);
|
||||
|
||||
let executionSuccessCount = 0;
|
||||
for (let i = 0; i < executionResponses.length; i++) {
|
||||
const response = executionResponses[i];
|
||||
|
||||
const success = check(response, {
|
||||
[`Graph ${i + 1} execution initiated`]: (r) =>
|
||||
r.status === 200 || r.status === 402,
|
||||
});
|
||||
|
||||
if (success) {
|
||||
executionSuccessCount++;
|
||||
graphExecutions.add(1);
|
||||
}
|
||||
}
|
||||
|
||||
console.log(
|
||||
`✅ VU ${__VU} executed ${executionSuccessCount}/${graphsToExecute.length} graphs`,
|
||||
);
|
||||
}
|
||||
|
||||
// Think time between iterations
|
||||
sleep(Math.random() * 2 + 1); // 1-3 seconds
|
||||
}
|
||||
|
||||
// Legacy functions removed - replaced by concurrent execution in main function
|
||||
// These functions are no longer used since implementing http.batch() for true concurrency
|
||||
|
||||
export function teardown(data) {
|
||||
console.log("🧹 Cleaning up graph execution load test...");
|
||||
console.log(`Total graph creations: ${graphCreations.value || 0}`);
|
||||
console.log(`Total graph executions: ${graphExecutions.value || 0}`);
|
||||
|
||||
const testDuration = Date.now() - data.timestamp;
|
||||
console.log(`Test completed in ${testDuration}ms`);
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
/**
|
||||
* Basic Connectivity Test
|
||||
*
|
||||
* Tests basic connectivity and authentication without requiring backend API access
|
||||
* This test validates that the core infrastructure is working correctly
|
||||
*/
|
||||
|
||||
import http from "k6/http";
|
||||
import { check } from "k6";
|
||||
import { getEnvironmentConfig } from "../../configs/environment.js";
|
||||
import { getPreAuthenticatedHeaders } from "../../configs/pre-authenticated-tokens.js";
|
||||
|
||||
const config = getEnvironmentConfig();
|
||||
|
||||
export const options = {
|
||||
stages: [
|
||||
{ duration: __ENV.RAMP_UP || "1m", target: parseInt(__ENV.VUS) || 1 },
|
||||
{ duration: __ENV.DURATION || "5m", target: parseInt(__ENV.VUS) || 1 },
|
||||
{ duration: __ENV.RAMP_DOWN || "1m", target: 0 },
|
||||
],
|
||||
thresholds: {
|
||||
checks: ["rate>0.70"], // Reduced from 0.85 due to auth timeouts under load
|
||||
http_req_duration: ["p(95)<30000"], // Increased for cloud testing with high concurrency
|
||||
http_req_failed: ["rate<0.6"], // Increased to account for auth timeouts
|
||||
},
|
||||
cloud: {
|
||||
projectID: __ENV.K6_CLOUD_PROJECT_ID,
|
||||
name: "AutoGPT Platform - Basic Connectivity & Auth Test",
|
||||
},
|
||||
// Timeout configurations to prevent early termination
|
||||
setupTimeout: "60s",
|
||||
teardownTimeout: "60s",
|
||||
noConnectionReuse: false,
|
||||
userAgent: "k6-load-test/1.0",
|
||||
};
|
||||
|
||||
export default function () {
|
||||
// Get load multiplier - how many concurrent requests each VU should make
|
||||
const requestsPerVU = parseInt(__ENV.REQUESTS_PER_VU) || 1;
|
||||
|
||||
try {
|
||||
// Get pre-authenticated headers
|
||||
const headers = getPreAuthenticatedHeaders(__VU);
|
||||
|
||||
// Handle authentication failure gracefully
|
||||
if (!headers || !headers.Authorization) {
|
||||
console.log(
|
||||
`⚠️ VU ${__VU} has no valid pre-authentication token - skipping iteration`,
|
||||
);
|
||||
check(null, {
|
||||
"Authentication: Failed gracefully without crashing VU": () => true,
|
||||
});
|
||||
return; // Exit iteration gracefully without crashing
|
||||
}
|
||||
|
||||
console.log(`🚀 VU ${__VU} making ${requestsPerVU} concurrent requests...`);
|
||||
|
||||
// Create array of request functions to run concurrently
|
||||
const requests = [];
|
||||
|
||||
for (let i = 0; i < requestsPerVU; i++) {
|
||||
requests.push({
|
||||
method: "GET",
|
||||
url: `${config.SUPABASE_URL}/rest/v1/`,
|
||||
params: { headers: { apikey: config.SUPABASE_ANON_KEY } },
|
||||
});
|
||||
|
||||
requests.push({
|
||||
method: "GET",
|
||||
url: `${config.API_BASE_URL}/health`,
|
||||
params: { headers },
|
||||
});
|
||||
}
|
||||
|
||||
// Execute all requests concurrently
|
||||
const responses = http.batch(requests);
|
||||
|
||||
// Validate results
|
||||
let supabaseSuccesses = 0;
|
||||
let backendSuccesses = 0;
|
||||
|
||||
for (let i = 0; i < responses.length; i++) {
|
||||
const response = responses[i];
|
||||
|
||||
if (i % 2 === 0) {
|
||||
// Supabase request
|
||||
const connectivityCheck = check(response, {
|
||||
"Supabase connectivity: Status is not 500": (r) => r.status !== 500,
|
||||
"Supabase connectivity: Response time < 5s": (r) =>
|
||||
r.timings.duration < 5000,
|
||||
});
|
||||
if (connectivityCheck) supabaseSuccesses++;
|
||||
} else {
|
||||
// Backend request
|
||||
const backendCheck = check(response, {
|
||||
"Backend server: Responds (any status)": (r) => r.status > 0,
|
||||
"Backend server: Response time < 5s": (r) =>
|
||||
r.timings.duration < 5000,
|
||||
});
|
||||
if (backendCheck) backendSuccesses++;
|
||||
}
|
||||
}
|
||||
|
||||
console.log(
|
||||
`✅ VU ${__VU} completed: ${supabaseSuccesses}/${requestsPerVU} Supabase, ${backendSuccesses}/${requestsPerVU} backend requests successful`,
|
||||
);
|
||||
|
||||
// Basic auth validation (once per iteration)
|
||||
const authCheck = check(headers, {
|
||||
"Authentication: Pre-auth token available": (h) =>
|
||||
h && h.Authorization && h.Authorization.length > 0,
|
||||
});
|
||||
|
||||
// JWT structure validation (once per iteration)
|
||||
const token = headers.Authorization.replace("Bearer ", "");
|
||||
const tokenParts = token.split(".");
|
||||
const tokenStructureCheck = check(tokenParts, {
|
||||
"JWT token: Has 3 parts (header.payload.signature)": (parts) =>
|
||||
parts.length === 3,
|
||||
"JWT token: Header is base64": (parts) =>
|
||||
parts[0] && parts[0].length > 10,
|
||||
"JWT token: Payload is base64": (parts) =>
|
||||
parts[1] && parts[1].length > 50,
|
||||
"JWT token: Signature exists": (parts) =>
|
||||
parts[2] && parts[2].length > 10,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error(`💥 Test failed: ${error.message}`);
|
||||
check(null, {
|
||||
"Test execution: No errors": () => false,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
export function teardown(data) {
|
||||
console.log(`🏁 Basic connectivity test completed`);
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
// Test individual API endpoints to isolate performance bottlenecks
|
||||
import http from "k6/http";
|
||||
import { check } from "k6";
|
||||
import { getEnvironmentConfig } from "../../configs/environment.js";
|
||||
import { getPreAuthenticatedHeaders } from "../../configs/pre-authenticated-tokens.js";
|
||||
|
||||
const config = getEnvironmentConfig();
|
||||
|
||||
export const options = {
|
||||
stages: [
|
||||
{ duration: __ENV.RAMP_UP || "10s", target: parseInt(__ENV.VUS) || 3 },
|
||||
{ duration: __ENV.DURATION || "20s", target: parseInt(__ENV.VUS) || 3 },
|
||||
{ duration: __ENV.RAMP_DOWN || "10s", target: 0 },
|
||||
],
|
||||
thresholds: {
|
||||
checks: ["rate>0.50"], // 50% success rate (was 70%)
|
||||
http_req_duration: ["p(95)<60000"], // P95 under 60s (was 5s)
|
||||
http_req_failed: ["rate<0.5"], // 50% failure rate allowed (was 30%)
|
||||
},
|
||||
cloud: {
|
||||
projectID: parseInt(__ENV.K6_CLOUD_PROJECT_ID) || 4254406,
|
||||
name: `AutoGPT Single Endpoint Test - ${__ENV.ENDPOINT || "credits"} API`,
|
||||
},
|
||||
};
|
||||
|
||||
export default function () {
|
||||
const endpoint = __ENV.ENDPOINT || "credits"; // credits, graphs, blocks, executions
|
||||
const concurrentRequests = parseInt(__ENV.CONCURRENT_REQUESTS) || 1;
|
||||
|
||||
try {
|
||||
const headers = getPreAuthenticatedHeaders(__VU);
|
||||
|
||||
if (!headers || !headers.Authorization) {
|
||||
console.log(
|
||||
`⚠️ VU ${__VU} has no valid pre-authentication token - skipping test`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(
|
||||
`🚀 VU ${__VU} testing /api/${endpoint} with ${concurrentRequests} concurrent requests`,
|
||||
);
|
||||
|
||||
if (concurrentRequests === 1) {
|
||||
// Single request mode (original behavior)
|
||||
const response = http.get(`${config.API_BASE_URL}/api/${endpoint}`, {
|
||||
headers,
|
||||
});
|
||||
|
||||
const success = check(response, {
|
||||
[`${endpoint} API: Status is 200`]: (r) => r.status === 200,
|
||||
[`${endpoint} API: Response time < 3s`]: (r) =>
|
||||
r.timings.duration < 3000,
|
||||
});
|
||||
|
||||
if (success) {
|
||||
console.log(
|
||||
`✅ VU ${__VU} /api/${endpoint} successful: ${response.timings.duration}ms`,
|
||||
);
|
||||
} else {
|
||||
console.log(
|
||||
`❌ VU ${__VU} /api/${endpoint} failed: ${response.status}, ${response.timings.duration}ms`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Concurrent requests mode using http.batch()
|
||||
const requests = [];
|
||||
for (let i = 0; i < concurrentRequests; i++) {
|
||||
requests.push({
|
||||
method: "GET",
|
||||
url: `${config.API_BASE_URL}/api/${endpoint}`,
|
||||
params: { headers },
|
||||
});
|
||||
}
|
||||
|
||||
const responses = http.batch(requests);
|
||||
|
||||
let successCount = 0;
|
||||
let totalTime = 0;
|
||||
|
||||
for (let i = 0; i < responses.length; i++) {
|
||||
const response = responses[i];
|
||||
const success = check(response, {
|
||||
[`${endpoint} API Request ${i + 1}: Status is 200`]: (r) =>
|
||||
r.status === 200,
|
||||
[`${endpoint} API Request ${i + 1}: Response time < 5s`]: (r) =>
|
||||
r.timings.duration < 5000,
|
||||
});
|
||||
|
||||
if (success) {
|
||||
successCount++;
|
||||
}
|
||||
totalTime += response.timings.duration;
|
||||
}
|
||||
|
||||
const avgTime = totalTime / responses.length;
|
||||
console.log(
|
||||
`✅ VU ${__VU} /api/${endpoint}: ${successCount}/${concurrentRequests} successful, avg: ${avgTime.toFixed(0)}ms`,
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`💥 VU ${__VU} error: ${error.message}`);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,508 @@
|
||||
import http from "k6/http";
|
||||
import { check, sleep, group } from "k6";
|
||||
import { Rate, Trend, Counter } from "k6/metrics";
|
||||
import {
|
||||
getEnvironmentConfig,
|
||||
PERFORMANCE_CONFIG,
|
||||
} from "../../configs/environment.js";
|
||||
import { getPreAuthenticatedHeaders } from "../../configs/pre-authenticated-tokens.js";
|
||||
|
||||
// Inline test data generators (simplified from utils/test-data.js)
|
||||
function generateTestGraph(name = null) {
|
||||
const graphName =
|
||||
name || `Load Test Graph ${Math.random().toString(36).substr(2, 9)}`;
|
||||
return {
|
||||
name: graphName,
|
||||
description: "Generated graph for load testing purposes",
|
||||
graph: {
|
||||
nodes: [],
|
||||
links: [],
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function generateExecutionInputs() {
|
||||
return { test_input: "load_test_value" };
|
||||
}
|
||||
|
||||
function generateScheduleData() {
|
||||
return { enabled: false };
|
||||
}
|
||||
|
||||
function generateAPIKeyRequest() {
|
||||
return { name: "Load Test API Key" };
|
||||
}
|
||||
|
||||
const config = getEnvironmentConfig();
|
||||
|
||||
// Custom metrics
|
||||
const userOperations = new Counter("user_operations_total");
|
||||
const graphOperations = new Counter("graph_operations_total");
|
||||
const executionOperations = new Counter("execution_operations_total");
|
||||
const apiResponseTime = new Trend("api_response_time");
|
||||
const authErrors = new Rate("auth_errors");
|
||||
|
||||
// Test configuration for normal load testing
|
||||
export const options = {
|
||||
stages: [
|
||||
{
|
||||
duration: __ENV.RAMP_UP || "1m",
|
||||
target: parseInt(__ENV.VUS) || PERFORMANCE_CONFIG.DEFAULT_VUS,
|
||||
},
|
||||
{
|
||||
duration: __ENV.DURATION || "5m",
|
||||
target: parseInt(__ENV.VUS) || PERFORMANCE_CONFIG.DEFAULT_VUS,
|
||||
},
|
||||
{ duration: __ENV.RAMP_DOWN || "1m", target: 0 },
|
||||
],
|
||||
// maxDuration: '15m', // Removed - not supported in k6 cloud
|
||||
thresholds: {
|
||||
checks: ["rate>0.50"], // Reduced for high concurrency complex operations
|
||||
http_req_duration: ["p(95)<60000", "p(99)<60000"], // Allow up to 60s response times
|
||||
http_req_failed: ["rate<0.5"], // Allow 50% failure rate for stress testing
|
||||
},
|
||||
cloud: {
|
||||
projectID: __ENV.K6_CLOUD_PROJECT_ID,
|
||||
name: "AutoGPT Platform - Full Platform Integration Test",
|
||||
},
|
||||
// Timeout configurations to prevent early termination
|
||||
setupTimeout: "60s",
|
||||
teardownTimeout: "60s",
|
||||
noConnectionReuse: false,
|
||||
userAgent: "k6-load-test/1.0",
|
||||
};
|
||||
|
||||
export function setup() {
|
||||
console.log("🎯 Setting up load test scenario...");
|
||||
return {
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
export default function (data) {
|
||||
// Get load multiplier - how many concurrent user journeys each VU should simulate
|
||||
const requestsPerVU = parseInt(__ENV.REQUESTS_PER_VU) || 1;
|
||||
|
||||
let headers;
|
||||
|
||||
try {
|
||||
headers = getPreAuthenticatedHeaders(__VU);
|
||||
} catch (error) {
|
||||
console.error(`❌ Authentication failed:`, error);
|
||||
authErrors.add(1);
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle authentication failure gracefully
|
||||
if (!headers || !headers.Authorization) {
|
||||
console.log(
|
||||
`⚠️ VU ${__VU} has no valid pre-authentication token - skipping comprehensive platform test`,
|
||||
);
|
||||
check(null, {
|
||||
"Comprehensive Platform: Failed gracefully without crashing VU": () =>
|
||||
true,
|
||||
});
|
||||
return; // Exit iteration gracefully without crashing
|
||||
}
|
||||
|
||||
console.log(
|
||||
`🚀 VU ${__VU} simulating ${requestsPerVU} realistic user workflows...`,
|
||||
);
|
||||
|
||||
// Create concurrent requests for all user journeys
|
||||
const requests = [];
|
||||
|
||||
// Simulate realistic user workflows instead of just API hammering
|
||||
for (let i = 0; i < requestsPerVU; i++) {
|
||||
// Workflow 1: User checking their dashboard
|
||||
requests.push({
|
||||
method: "GET",
|
||||
url: `${config.API_BASE_URL}/api/credits`,
|
||||
params: { headers },
|
||||
});
|
||||
requests.push({
|
||||
method: "GET",
|
||||
url: `${config.API_BASE_URL}/api/graphs`,
|
||||
params: { headers },
|
||||
});
|
||||
|
||||
// Workflow 2: User exploring available blocks for building agents
|
||||
requests.push({
|
||||
method: "GET",
|
||||
url: `${config.API_BASE_URL}/api/blocks`,
|
||||
params: { headers },
|
||||
});
|
||||
|
||||
// Workflow 3: User monitoring their recent executions
|
||||
requests.push({
|
||||
method: "GET",
|
||||
url: `${config.API_BASE_URL}/api/executions`,
|
||||
params: { headers },
|
||||
});
|
||||
}
|
||||
|
||||
console.log(
|
||||
`📊 Executing ${requests.length} requests across realistic user workflows...`,
|
||||
);
|
||||
|
||||
// Execute all requests concurrently
|
||||
const responses = http.batch(requests);
|
||||
|
||||
// Process results and count successes
|
||||
let creditsSuccesses = 0,
|
||||
graphsSuccesses = 0,
|
||||
blocksSuccesses = 0,
|
||||
executionsSuccesses = 0;
|
||||
|
||||
for (let i = 0; i < responses.length; i++) {
|
||||
const response = responses[i];
|
||||
const operationType = i % 4; // Each set of 4 requests: 0=credits, 1=graphs, 2=blocks, 3=executions
|
||||
|
||||
switch (operationType) {
|
||||
case 0: // Dashboard: Check credits
|
||||
if (
|
||||
check(response, {
|
||||
"Dashboard: User credits loaded successfully": (r) =>
|
||||
r.status === 200,
|
||||
})
|
||||
) {
|
||||
creditsSuccesses++;
|
||||
userOperations.add(1);
|
||||
}
|
||||
break;
|
||||
case 1: // Dashboard: View graphs
|
||||
if (
|
||||
check(response, {
|
||||
"Dashboard: User graphs loaded successfully": (r) =>
|
||||
r.status === 200,
|
||||
})
|
||||
) {
|
||||
graphsSuccesses++;
|
||||
graphOperations.add(1);
|
||||
}
|
||||
break;
|
||||
case 2: // Exploration: Browse available blocks
|
||||
if (
|
||||
check(response, {
|
||||
"Block Explorer: Available blocks loaded successfully": (r) =>
|
||||
r.status === 200,
|
||||
})
|
||||
) {
|
||||
blocksSuccesses++;
|
||||
userOperations.add(1);
|
||||
}
|
||||
break;
|
||||
case 3: // Monitoring: Check execution history
|
||||
if (
|
||||
check(response, {
|
||||
"Execution Monitor: Recent executions loaded successfully": (r) =>
|
||||
r.status === 200,
|
||||
})
|
||||
) {
|
||||
executionsSuccesses++;
|
||||
userOperations.add(1);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
console.log(
|
||||
`✅ VU ${__VU} completed realistic workflows: ${creditsSuccesses} dashboard checks, ${graphsSuccesses} graph views, ${blocksSuccesses} block explorations, ${executionsSuccesses} execution monitors`,
|
||||
);
|
||||
|
||||
// Think time between user sessions
|
||||
sleep(Math.random() * 3 + 1); // 1-4 seconds
|
||||
}
|
||||
|
||||
function userProfileJourney(headers) {
|
||||
const startTime = Date.now();
|
||||
|
||||
// 1. Get user credits (JWT-only endpoint)
|
||||
const creditsResponse = http.get(`${config.API_BASE_URL}/api/credits`, {
|
||||
headers,
|
||||
});
|
||||
|
||||
userOperations.add(1);
|
||||
|
||||
check(creditsResponse, {
|
||||
"User credits loaded successfully": (r) => r.status === 200,
|
||||
});
|
||||
|
||||
// 2. Check onboarding status
|
||||
const onboardingResponse = http.get(`${config.API_BASE_URL}/api/onboarding`, {
|
||||
headers,
|
||||
});
|
||||
|
||||
userOperations.add(1);
|
||||
|
||||
check(onboardingResponse, {
|
||||
"Onboarding status loaded": (r) => r.status === 200,
|
||||
});
|
||||
|
||||
apiResponseTime.add(Date.now() - startTime);
|
||||
}
|
||||
|
||||
function graphManagementJourney(headers) {
|
||||
const startTime = Date.now();
|
||||
|
||||
// 1. List existing graphs
|
||||
const listResponse = http.get(`${config.API_BASE_URL}/api/graphs`, {
|
||||
headers,
|
||||
});
|
||||
|
||||
graphOperations.add(1);
|
||||
|
||||
const listSuccess = check(listResponse, {
|
||||
"Graphs list loaded successfully": (r) => r.status === 200,
|
||||
});
|
||||
|
||||
// 2. Create a new graph (20% of users)
|
||||
if (Math.random() < 0.2) {
|
||||
const graphData = generateTestGraph();
|
||||
|
||||
const createResponse = http.post(
|
||||
`${config.API_BASE_URL}/api/graphs`,
|
||||
JSON.stringify(graphData),
|
||||
{ headers },
|
||||
);
|
||||
|
||||
graphOperations.add(1);
|
||||
|
||||
const createSuccess = check(createResponse, {
|
||||
"Graph created successfully": (r) => r.status === 200,
|
||||
});
|
||||
|
||||
if (createSuccess && createResponse.status === 200) {
|
||||
try {
|
||||
const createdGraph = JSON.parse(createResponse.body);
|
||||
|
||||
// 3. Get the created graph details
|
||||
const getResponse = http.get(
|
||||
`${config.API_BASE_URL}/api/graphs/${createdGraph.id}`,
|
||||
{ headers },
|
||||
);
|
||||
|
||||
graphOperations.add(1);
|
||||
|
||||
check(getResponse, {
|
||||
"Graph details loaded": (r) => r.status === 200,
|
||||
});
|
||||
|
||||
// 4. Execute the graph (50% chance)
|
||||
if (Math.random() < 0.5) {
|
||||
executeGraphScenario(createdGraph, headers);
|
||||
}
|
||||
|
||||
// 5. Create schedule for graph (10% chance)
|
||||
if (Math.random() < 0.1) {
|
||||
createScheduleScenario(createdGraph.id, headers);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error handling created graph:", error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Work with existing graphs (if any)
|
||||
if (listSuccess && listResponse.status === 200) {
|
||||
try {
|
||||
const existingGraphs = JSON.parse(listResponse.body);
|
||||
|
||||
if (existingGraphs.length > 0) {
|
||||
// Pick a random existing graph
|
||||
const randomGraph =
|
||||
existingGraphs[Math.floor(Math.random() * existingGraphs.length)];
|
||||
|
||||
// Get graph details
|
||||
const getResponse = http.get(
|
||||
`${config.API_BASE_URL}/api/graphs/${randomGraph.id}`,
|
||||
{ headers },
|
||||
);
|
||||
|
||||
graphOperations.add(1);
|
||||
|
||||
check(getResponse, {
|
||||
"Existing graph details loaded": (r) => r.status === 200,
|
||||
});
|
||||
|
||||
// Execute existing graph (30% chance)
|
||||
if (Math.random() < 0.3) {
|
||||
executeGraphScenario(randomGraph, headers);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error working with existing graphs:", error);
|
||||
}
|
||||
}
|
||||
|
||||
apiResponseTime.add(Date.now() - startTime);
|
||||
}
|
||||
|
||||
function executeGraphScenario(graph, headers) {
|
||||
const startTime = Date.now();
|
||||
|
||||
const executionInputs = generateExecutionInputs();
|
||||
|
||||
const executeResponse = http.post(
|
||||
`${config.API_BASE_URL}/api/graphs/${graph.id}/execute/${graph.version}`,
|
||||
JSON.stringify({
|
||||
inputs: executionInputs,
|
||||
credentials_inputs: {},
|
||||
}),
|
||||
{ headers },
|
||||
);
|
||||
|
||||
executionOperations.add(1);
|
||||
|
||||
const executeSuccess = check(executeResponse, {
|
||||
"Graph execution initiated": (r) => r.status === 200 || r.status === 402, // 402 = insufficient credits
|
||||
});
|
||||
|
||||
if (executeSuccess && executeResponse.status === 200) {
|
||||
try {
|
||||
const execution = JSON.parse(executeResponse.body);
|
||||
|
||||
// Monitor execution status (simulate user checking results)
|
||||
// Note: setTimeout doesn't work in k6, so we'll check status immediately
|
||||
const statusResponse = http.get(
|
||||
`${config.API_BASE_URL}/api/graphs/${graph.id}/executions/${execution.id}`,
|
||||
{ headers },
|
||||
);
|
||||
|
||||
executionOperations.add(1);
|
||||
|
||||
check(statusResponse, {
|
||||
"Execution status retrieved": (r) => r.status === 200,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error monitoring execution:", error);
|
||||
}
|
||||
}
|
||||
|
||||
apiResponseTime.add(Date.now() - startTime);
|
||||
}
|
||||
|
||||
function createScheduleScenario(graphId, headers) {
|
||||
const scheduleData = generateScheduleData(graphId);
|
||||
|
||||
const scheduleResponse = http.post(
|
||||
`${config.API_BASE_URL}/api/graphs/${graphId}/schedules`,
|
||||
JSON.stringify(scheduleData),
|
||||
{ headers },
|
||||
);
|
||||
|
||||
graphOperations.add(1);
|
||||
|
||||
check(scheduleResponse, {
|
||||
"Schedule created successfully": (r) => r.status === 200,
|
||||
});
|
||||
}
|
||||
|
||||
function blockOperationsJourney(headers) {
|
||||
const startTime = Date.now();
|
||||
|
||||
// 1. Get available blocks
|
||||
const blocksResponse = http.get(`${config.API_BASE_URL}/api/blocks`, {
|
||||
headers,
|
||||
});
|
||||
|
||||
userOperations.add(1);
|
||||
|
||||
const blocksSuccess = check(blocksResponse, {
|
||||
"Blocks list loaded": (r) => r.status === 200,
|
||||
});
|
||||
|
||||
// 2. Execute some blocks directly (simulate testing)
|
||||
if (blocksSuccess && Math.random() < 0.3) {
|
||||
// Execute GetCurrentTimeBlock (simple, fast block)
|
||||
const timeBlockResponse = http.post(
|
||||
`${config.API_BASE_URL}/api/blocks/a892b8d9-3e4e-4e9c-9c1e-75f8efcf1bfa/execute`,
|
||||
JSON.stringify({
|
||||
trigger: "test",
|
||||
format_type: {
|
||||
discriminator: "iso8601",
|
||||
timezone: "UTC",
|
||||
},
|
||||
}),
|
||||
{ headers },
|
||||
);
|
||||
|
||||
userOperations.add(1);
|
||||
|
||||
check(timeBlockResponse, {
|
||||
"Time block executed or handled gracefully": (r) =>
|
||||
r.status === 200 || r.status === 500, // 500 = user_context missing (expected)
|
||||
});
|
||||
}
|
||||
|
||||
apiResponseTime.add(Date.now() - startTime);
|
||||
}
|
||||
|
||||
function systemOperationsJourney(headers) {
|
||||
const startTime = Date.now();
|
||||
|
||||
// 1. Check executions list (simulate monitoring)
|
||||
const executionsResponse = http.get(`${config.API_BASE_URL}/api/executions`, {
|
||||
headers,
|
||||
});
|
||||
|
||||
userOperations.add(1);
|
||||
|
||||
check(executionsResponse, {
|
||||
"Executions list loaded": (r) => r.status === 200,
|
||||
});
|
||||
|
||||
// 2. Check schedules (if any)
|
||||
const schedulesResponse = http.get(`${config.API_BASE_URL}/api/schedules`, {
|
||||
headers,
|
||||
});
|
||||
|
||||
userOperations.add(1);
|
||||
|
||||
check(schedulesResponse, {
|
||||
"Schedules list loaded": (r) => r.status === 200,
|
||||
});
|
||||
|
||||
// 3. Check API keys (simulate user managing access)
|
||||
if (Math.random() < 0.1) {
|
||||
// 10% of users check API keys
|
||||
const apiKeysResponse = http.get(`${config.API_BASE_URL}/api/api-keys`, {
|
||||
headers,
|
||||
});
|
||||
|
||||
userOperations.add(1);
|
||||
|
||||
check(apiKeysResponse, {
|
||||
"API keys list loaded": (r) => r.status === 200,
|
||||
});
|
||||
|
||||
// Occasionally create new API key (5% chance)
|
||||
if (Math.random() < 0.05) {
|
||||
const keyData = generateAPIKeyRequest();
|
||||
|
||||
const createKeyResponse = http.post(
|
||||
`${config.API_BASE_URL}/api/api-keys`,
|
||||
JSON.stringify(keyData),
|
||||
{ headers },
|
||||
);
|
||||
|
||||
userOperations.add(1);
|
||||
|
||||
check(createKeyResponse, {
|
||||
"API key created successfully": (r) => r.status === 200,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
apiResponseTime.add(Date.now() - startTime);
|
||||
}
|
||||
|
||||
export function teardown(data) {
|
||||
console.log("🧹 Cleaning up load test...");
|
||||
console.log(`Total user operations: ${userOperations.value}`);
|
||||
console.log(`Total graph operations: ${graphOperations.value}`);
|
||||
console.log(`Total execution operations: ${executionOperations.value}`);
|
||||
|
||||
const testDuration = Date.now() - data.timestamp;
|
||||
console.log(`Test completed in ${testDuration}ms`);
|
||||
}
|
||||
@@ -0,0 +1,536 @@
|
||||
import { check } from "k6";
|
||||
import http from "k6/http";
|
||||
import { Counter } from "k6/metrics";
|
||||
|
||||
import { getEnvironmentConfig } from "../../configs/environment.js";
|
||||
import { getPreAuthenticatedHeaders } from "../../configs/pre-authenticated-tokens.js";
|
||||
|
||||
const config = getEnvironmentConfig();
|
||||
const BASE_URL = config.API_BASE_URL;
|
||||
|
||||
// Custom metrics
|
||||
const libraryRequests = new Counter("library_requests_total");
|
||||
const successfulRequests = new Counter("successful_requests_total");
|
||||
const failedRequests = new Counter("failed_requests_total");
|
||||
const authenticationAttempts = new Counter("authentication_attempts_total");
|
||||
const authenticationSuccesses = new Counter("authentication_successes_total");
|
||||
|
||||
// Test configuration
|
||||
const VUS = parseInt(__ENV.VUS) || 5;
|
||||
const DURATION = __ENV.DURATION || "2m";
|
||||
const RAMP_UP = __ENV.RAMP_UP || "30s";
|
||||
const RAMP_DOWN = __ENV.RAMP_DOWN || "30s";
|
||||
const REQUESTS_PER_VU = parseInt(__ENV.REQUESTS_PER_VU) || 5;
|
||||
|
||||
// Performance thresholds for authenticated endpoints
|
||||
const THRESHOLD_P95 = parseInt(__ENV.THRESHOLD_P95) || 10000; // 10s for authenticated endpoints
|
||||
const THRESHOLD_P99 = parseInt(__ENV.THRESHOLD_P99) || 20000; // 20s for authenticated endpoints
|
||||
const THRESHOLD_ERROR_RATE = parseFloat(__ENV.THRESHOLD_ERROR_RATE) || 0.1; // 10% error rate
|
||||
const THRESHOLD_CHECK_RATE = parseFloat(__ENV.THRESHOLD_CHECK_RATE) || 0.85; // 85% success rate
|
||||
|
||||
export const options = {
|
||||
stages: [
|
||||
{ duration: RAMP_UP, target: VUS },
|
||||
{ duration: DURATION, target: VUS },
|
||||
{ duration: RAMP_DOWN, target: 0 },
|
||||
],
|
||||
thresholds: {
|
||||
http_req_duration: [
|
||||
{ threshold: `p(95)<${THRESHOLD_P95}`, abortOnFail: false },
|
||||
{ threshold: `p(99)<${THRESHOLD_P99}`, abortOnFail: false },
|
||||
],
|
||||
http_req_failed: [
|
||||
{ threshold: `rate<${THRESHOLD_ERROR_RATE}`, abortOnFail: false },
|
||||
],
|
||||
checks: [{ threshold: `rate>${THRESHOLD_CHECK_RATE}`, abortOnFail: false }],
|
||||
},
|
||||
tags: {
|
||||
test_type: "marketplace_library_authorized",
|
||||
environment: __ENV.K6_ENVIRONMENT || "DEV",
|
||||
},
|
||||
};
|
||||
|
||||
export default function () {
|
||||
console.log(`📚 VU ${__VU} starting authenticated library journey...`);
|
||||
|
||||
// Get pre-authenticated headers
|
||||
const headers = getPreAuthenticatedHeaders(__VU);
|
||||
if (!headers || !headers.Authorization) {
|
||||
console.log(`❌ VU ${__VU} authentication failed, skipping iteration`);
|
||||
authenticationAttempts.add(1);
|
||||
return;
|
||||
}
|
||||
|
||||
authenticationAttempts.add(1);
|
||||
authenticationSuccesses.add(1);
|
||||
|
||||
// Run multiple library operations per iteration
|
||||
for (let i = 0; i < REQUESTS_PER_VU; i++) {
|
||||
console.log(
|
||||
`🔄 VU ${__VU} starting library operation ${i + 1}/${REQUESTS_PER_VU}...`,
|
||||
);
|
||||
authenticatedLibraryJourney(headers);
|
||||
}
|
||||
}
|
||||
|
||||
function authenticatedLibraryJourney(headers) {
|
||||
const journeyStart = Date.now();
|
||||
|
||||
// Step 1: Get user's library agents
|
||||
console.log(`📖 VU ${__VU} fetching user library agents...`);
|
||||
const libraryAgentsResponse = http.get(
|
||||
`${BASE_URL}/api/library/agents?page=1&page_size=20`,
|
||||
{ headers },
|
||||
);
|
||||
|
||||
libraryRequests.add(1);
|
||||
const librarySuccess = check(libraryAgentsResponse, {
|
||||
"Library agents endpoint returns 200": (r) => r.status === 200,
|
||||
"Library agents response has data": (r) => {
|
||||
try {
|
||||
const json = r.json();
|
||||
return json && json.agents && Array.isArray(json.agents);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Library agents response time < 10s": (r) => r.timings.duration < 10000,
|
||||
});
|
||||
|
||||
if (librarySuccess) {
|
||||
successfulRequests.add(1);
|
||||
} else {
|
||||
failedRequests.add(1);
|
||||
console.log(
|
||||
`⚠️ VU ${__VU} library agents request failed: ${libraryAgentsResponse.status} - ${libraryAgentsResponse.body}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Step 2: Get favorite agents
|
||||
console.log(`⭐ VU ${__VU} fetching favorite library agents...`);
|
||||
const favoriteAgentsResponse = http.get(
|
||||
`${BASE_URL}/api/library/agents/favorites?page=1&page_size=10`,
|
||||
{ headers },
|
||||
);
|
||||
|
||||
libraryRequests.add(1);
|
||||
const favoritesSuccess = check(favoriteAgentsResponse, {
|
||||
"Favorite agents endpoint returns 200": (r) => r.status === 200,
|
||||
"Favorite agents response has data": (r) => {
|
||||
try {
|
||||
const json = r.json();
|
||||
return json && json.agents !== undefined && Array.isArray(json.agents);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Favorite agents response time < 10s": (r) => r.timings.duration < 10000,
|
||||
});
|
||||
|
||||
if (favoritesSuccess) {
|
||||
successfulRequests.add(1);
|
||||
} else {
|
||||
failedRequests.add(1);
|
||||
console.log(
|
||||
`⚠️ VU ${__VU} favorite agents request failed: ${favoriteAgentsResponse.status}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Step 3: Add marketplace agent to library (simulate discovering and adding an agent)
|
||||
console.log(`🛍️ VU ${__VU} browsing marketplace to add agent...`);
|
||||
|
||||
// First get available store agents to find one to add
|
||||
const storeAgentsResponse = http.get(
|
||||
`${BASE_URL}/api/store/agents?page=1&page_size=5`,
|
||||
);
|
||||
|
||||
libraryRequests.add(1);
|
||||
const storeAgentsSuccess = check(storeAgentsResponse, {
|
||||
"Store agents endpoint returns 200": (r) => r.status === 200,
|
||||
"Store agents response has data": (r) => {
|
||||
try {
|
||||
const json = r.json();
|
||||
return (
|
||||
json &&
|
||||
json.agents &&
|
||||
Array.isArray(json.agents) &&
|
||||
json.agents.length > 0
|
||||
);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
if (storeAgentsSuccess) {
|
||||
successfulRequests.add(1);
|
||||
|
||||
try {
|
||||
const storeAgentsJson = storeAgentsResponse.json();
|
||||
if (storeAgentsJson?.agents && storeAgentsJson.agents.length > 0) {
|
||||
const randomStoreAgent =
|
||||
storeAgentsJson.agents[
|
||||
Math.floor(Math.random() * storeAgentsJson.agents.length)
|
||||
];
|
||||
|
||||
if (randomStoreAgent?.store_listing_version_id) {
|
||||
console.log(
|
||||
`➕ VU ${__VU} adding agent "${randomStoreAgent.name || "Unknown"}" to library...`,
|
||||
);
|
||||
|
||||
const addAgentPayload = {
|
||||
store_listing_version_id: randomStoreAgent.store_listing_version_id,
|
||||
};
|
||||
|
||||
const addAgentResponse = http.post(
|
||||
`${BASE_URL}/api/library/agents`,
|
||||
JSON.stringify(addAgentPayload),
|
||||
{ headers },
|
||||
);
|
||||
|
||||
libraryRequests.add(1);
|
||||
const addAgentSuccess = check(addAgentResponse, {
|
||||
"Add agent returns 201 or 200 (created/already exists)": (r) =>
|
||||
r.status === 201 || r.status === 200,
|
||||
"Add agent response has id": (r) => {
|
||||
try {
|
||||
const json = r.json();
|
||||
return json && json.id;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Add agent response time < 15s": (r) => r.timings.duration < 15000,
|
||||
});
|
||||
|
||||
if (addAgentSuccess) {
|
||||
successfulRequests.add(1);
|
||||
|
||||
// Step 4: Update the added agent (mark as favorite)
|
||||
try {
|
||||
const addedAgentJson = addAgentResponse.json();
|
||||
if (addedAgentJson?.id) {
|
||||
console.log(`⭐ VU ${__VU} marking agent as favorite...`);
|
||||
|
||||
const updatePayload = {
|
||||
is_favorite: true,
|
||||
auto_update_version: true,
|
||||
};
|
||||
|
||||
const updateAgentResponse = http.patch(
|
||||
`${BASE_URL}/api/library/agents/${addedAgentJson.id}`,
|
||||
JSON.stringify(updatePayload),
|
||||
{ headers },
|
||||
);
|
||||
|
||||
libraryRequests.add(1);
|
||||
const updateSuccess = check(updateAgentResponse, {
|
||||
"Update agent returns 200": (r) => r.status === 200,
|
||||
"Update agent response has updated data": (r) => {
|
||||
try {
|
||||
const json = r.json();
|
||||
return json && json.id && json.is_favorite === true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Update agent response time < 10s": (r) =>
|
||||
r.timings.duration < 10000,
|
||||
});
|
||||
|
||||
if (updateSuccess) {
|
||||
successfulRequests.add(1);
|
||||
} else {
|
||||
failedRequests.add(1);
|
||||
console.log(
|
||||
`⚠️ VU ${__VU} update agent failed: ${updateAgentResponse.status}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Step 5: Get specific library agent details
|
||||
console.log(`📄 VU ${__VU} fetching agent details...`);
|
||||
const agentDetailsResponse = http.get(
|
||||
`${BASE_URL}/api/library/agents/${addedAgentJson.id}`,
|
||||
{ headers },
|
||||
);
|
||||
|
||||
libraryRequests.add(1);
|
||||
const detailsSuccess = check(agentDetailsResponse, {
|
||||
"Agent details returns 200": (r) => r.status === 200,
|
||||
"Agent details response has complete data": (r) => {
|
||||
try {
|
||||
const json = r.json();
|
||||
return json && json.id && json.name && json.graph_id;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Agent details response time < 10s": (r) =>
|
||||
r.timings.duration < 10000,
|
||||
});
|
||||
|
||||
if (detailsSuccess) {
|
||||
successfulRequests.add(1);
|
||||
} else {
|
||||
failedRequests.add(1);
|
||||
console.log(
|
||||
`⚠️ VU ${__VU} agent details failed: ${agentDetailsResponse.status}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Step 6: Fork the library agent (simulate user customization)
|
||||
console.log(`🍴 VU ${__VU} forking agent for customization...`);
|
||||
const forkAgentResponse = http.post(
|
||||
`${BASE_URL}/api/library/agents/${addedAgentJson.id}/fork`,
|
||||
"",
|
||||
{ headers },
|
||||
);
|
||||
|
||||
libraryRequests.add(1);
|
||||
const forkSuccess = check(forkAgentResponse, {
|
||||
"Fork agent returns 200": (r) => r.status === 200,
|
||||
"Fork agent response has new agent data": (r) => {
|
||||
try {
|
||||
const json = r.json();
|
||||
return json && json.id && json.id !== addedAgentJson.id; // Should be different ID
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Fork agent response time < 15s": (r) =>
|
||||
r.timings.duration < 15000,
|
||||
});
|
||||
|
||||
if (forkSuccess) {
|
||||
successfulRequests.add(1);
|
||||
} else {
|
||||
failedRequests.add(1);
|
||||
console.log(
|
||||
`⚠️ VU ${__VU} fork agent failed: ${forkAgentResponse.status}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.warn(
|
||||
`⚠️ VU ${__VU} failed to parse added agent response: ${e}`,
|
||||
);
|
||||
failedRequests.add(1);
|
||||
}
|
||||
} else {
|
||||
failedRequests.add(1);
|
||||
console.log(
|
||||
`⚠️ VU ${__VU} add agent failed: ${addAgentResponse.status} - ${addAgentResponse.body}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.warn(`⚠️ VU ${__VU} failed to parse store agents data: ${e}`);
|
||||
failedRequests.add(1);
|
||||
}
|
||||
} else {
|
||||
failedRequests.add(1);
|
||||
console.log(
|
||||
`⚠️ VU ${__VU} store agents request failed: ${storeAgentsResponse.status}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Step 7: Search library agents
|
||||
const searchTerms = ["automation", "api", "data", "social", "productivity"];
|
||||
const randomSearchTerm =
|
||||
searchTerms[Math.floor(Math.random() * searchTerms.length)];
|
||||
|
||||
console.log(`🔍 VU ${__VU} searching library for "${randomSearchTerm}"...`);
|
||||
const searchLibraryResponse = http.get(
|
||||
`${BASE_URL}/api/library/agents?search_term=${encodeURIComponent(randomSearchTerm)}&page=1&page_size=10`,
|
||||
{ headers },
|
||||
);
|
||||
|
||||
libraryRequests.add(1);
|
||||
const searchLibrarySuccess = check(searchLibraryResponse, {
|
||||
"Search library returns 200": (r) => r.status === 200,
|
||||
"Search library response has data": (r) => {
|
||||
try {
|
||||
const json = r.json();
|
||||
return json && json.agents !== undefined && Array.isArray(json.agents);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Search library response time < 10s": (r) => r.timings.duration < 10000,
|
||||
});
|
||||
|
||||
if (searchLibrarySuccess) {
|
||||
successfulRequests.add(1);
|
||||
} else {
|
||||
failedRequests.add(1);
|
||||
console.log(
|
||||
`⚠️ VU ${__VU} search library failed: ${searchLibraryResponse.status}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Step 8: Get library agent by graph ID (simulate finding agent by backend graph)
|
||||
if (libraryAgentsResponse.status === 200) {
|
||||
try {
|
||||
const libraryJson = libraryAgentsResponse.json();
|
||||
if (libraryJson?.agents && libraryJson.agents.length > 0) {
|
||||
const randomLibraryAgent =
|
||||
libraryJson.agents[
|
||||
Math.floor(Math.random() * libraryJson.agents.length)
|
||||
];
|
||||
|
||||
if (randomLibraryAgent?.graph_id) {
|
||||
console.log(
|
||||
`🔗 VU ${__VU} fetching agent by graph ID "${randomLibraryAgent.graph_id}"...`,
|
||||
);
|
||||
const agentByGraphResponse = http.get(
|
||||
`${BASE_URL}/api/library/agents/by-graph/${randomLibraryAgent.graph_id}`,
|
||||
{ headers },
|
||||
);
|
||||
|
||||
libraryRequests.add(1);
|
||||
const agentByGraphSuccess = check(agentByGraphResponse, {
|
||||
"Agent by graph ID returns 200": (r) => r.status === 200,
|
||||
"Agent by graph response has data": (r) => {
|
||||
try {
|
||||
const json = r.json();
|
||||
return (
|
||||
json &&
|
||||
json.id &&
|
||||
json.graph_id === randomLibraryAgent.graph_id
|
||||
);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Agent by graph response time < 10s": (r) =>
|
||||
r.timings.duration < 10000,
|
||||
});
|
||||
|
||||
if (agentByGraphSuccess) {
|
||||
successfulRequests.add(1);
|
||||
} else {
|
||||
failedRequests.add(1);
|
||||
console.log(
|
||||
`⚠️ VU ${__VU} agent by graph request failed: ${agentByGraphResponse.status}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.warn(
|
||||
`⚠️ VU ${__VU} failed to parse library agents for graph lookup: ${e}`,
|
||||
);
|
||||
failedRequests.add(1);
|
||||
}
|
||||
}
|
||||
|
||||
const journeyDuration = Date.now() - journeyStart;
|
||||
console.log(
|
||||
`✅ VU ${__VU} completed authenticated library journey in ${journeyDuration}ms`,
|
||||
);
|
||||
}
|
||||
|
||||
export function handleSummary(data) {
|
||||
const summary = {
|
||||
test_type: "Marketplace Library Authorized Access Load Test",
|
||||
environment: __ENV.K6_ENVIRONMENT || "DEV",
|
||||
configuration: {
|
||||
virtual_users: VUS,
|
||||
duration: DURATION,
|
||||
ramp_up: RAMP_UP,
|
||||
ramp_down: RAMP_DOWN,
|
||||
requests_per_vu: REQUESTS_PER_VU,
|
||||
},
|
||||
performance_metrics: {
|
||||
total_requests: data.metrics.http_reqs?.count || 0,
|
||||
failed_requests: data.metrics.http_req_failed?.values?.passes || 0,
|
||||
avg_response_time: data.metrics.http_req_duration?.values?.avg || 0,
|
||||
p95_response_time: data.metrics.http_req_duration?.values?.p95 || 0,
|
||||
p99_response_time: data.metrics.http_req_duration?.values?.p99 || 0,
|
||||
},
|
||||
custom_metrics: {
|
||||
library_requests: data.metrics.library_requests_total?.values?.count || 0,
|
||||
successful_requests:
|
||||
data.metrics.successful_requests_total?.values?.count || 0,
|
||||
failed_requests: data.metrics.failed_requests_total?.values?.count || 0,
|
||||
authentication_attempts:
|
||||
data.metrics.authentication_attempts_total?.values?.count || 0,
|
||||
authentication_successes:
|
||||
data.metrics.authentication_successes_total?.values?.count || 0,
|
||||
},
|
||||
thresholds_met: {
|
||||
p95_threshold:
|
||||
(data.metrics.http_req_duration?.values?.p95 || 0) < THRESHOLD_P95,
|
||||
p99_threshold:
|
||||
(data.metrics.http_req_duration?.values?.p99 || 0) < THRESHOLD_P99,
|
||||
error_rate_threshold:
|
||||
(data.metrics.http_req_failed?.values?.rate || 0) <
|
||||
THRESHOLD_ERROR_RATE,
|
||||
check_rate_threshold:
|
||||
(data.metrics.checks?.values?.rate || 0) > THRESHOLD_CHECK_RATE,
|
||||
},
|
||||
authentication_metrics: {
|
||||
auth_success_rate:
|
||||
(data.metrics.authentication_successes_total?.values?.count || 0) /
|
||||
Math.max(
|
||||
1,
|
||||
data.metrics.authentication_attempts_total?.values?.count || 0,
|
||||
),
|
||||
},
|
||||
user_journey_coverage: [
|
||||
"Authenticate with valid credentials",
|
||||
"Fetch user library agents",
|
||||
"Browse favorite library agents",
|
||||
"Discover marketplace agents",
|
||||
"Add marketplace agent to library",
|
||||
"Update agent preferences (favorites)",
|
||||
"View detailed agent information",
|
||||
"Fork agent for customization",
|
||||
"Search library agents by term",
|
||||
"Lookup agent by graph ID",
|
||||
],
|
||||
};
|
||||
|
||||
console.log("\n📚 MARKETPLACE LIBRARY AUTHORIZED TEST SUMMARY");
|
||||
console.log("==============================================");
|
||||
console.log(`Environment: ${summary.environment}`);
|
||||
console.log(`Virtual Users: ${summary.configuration.virtual_users}`);
|
||||
console.log(`Duration: ${summary.configuration.duration}`);
|
||||
console.log(`Requests per VU: ${summary.configuration.requests_per_vu}`);
|
||||
console.log(`Total Requests: ${summary.performance_metrics.total_requests}`);
|
||||
console.log(
|
||||
`Successful Requests: ${summary.custom_metrics.successful_requests}`,
|
||||
);
|
||||
console.log(`Failed Requests: ${summary.custom_metrics.failed_requests}`);
|
||||
console.log(
|
||||
`Auth Success Rate: ${Math.round(summary.authentication_metrics.auth_success_rate * 100)}%`,
|
||||
);
|
||||
console.log(
|
||||
`Average Response Time: ${Math.round(summary.performance_metrics.avg_response_time)}ms`,
|
||||
);
|
||||
console.log(
|
||||
`95th Percentile: ${Math.round(summary.performance_metrics.p95_response_time)}ms`,
|
||||
);
|
||||
console.log(
|
||||
`99th Percentile: ${Math.round(summary.performance_metrics.p99_response_time)}ms`,
|
||||
);
|
||||
|
||||
console.log("\n🎯 Threshold Status:");
|
||||
console.log(
|
||||
`P95 < ${THRESHOLD_P95}ms: ${summary.thresholds_met.p95_threshold ? "✅" : "❌"}`,
|
||||
);
|
||||
console.log(
|
||||
`P99 < ${THRESHOLD_P99}ms: ${summary.thresholds_met.p99_threshold ? "✅" : "❌"}`,
|
||||
);
|
||||
console.log(
|
||||
`Error Rate < ${THRESHOLD_ERROR_RATE * 100}%: ${summary.thresholds_met.error_rate_threshold ? "✅" : "❌"}`,
|
||||
);
|
||||
console.log(
|
||||
`Check Rate > ${THRESHOLD_CHECK_RATE * 100}%: ${summary.thresholds_met.check_rate_threshold ? "✅" : "❌"}`,
|
||||
);
|
||||
|
||||
return {
|
||||
stdout: JSON.stringify(summary, null, 2),
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,465 @@
|
||||
import { check } from "k6";
|
||||
import http from "k6/http";
|
||||
import { Counter } from "k6/metrics";
|
||||
|
||||
import { getEnvironmentConfig } from "../../configs/environment.js";
|
||||
|
||||
const config = getEnvironmentConfig();
|
||||
const BASE_URL = config.API_BASE_URL;
|
||||
|
||||
// Custom metrics
|
||||
const marketplaceRequests = new Counter("marketplace_requests_total");
|
||||
const successfulRequests = new Counter("successful_requests_total");
|
||||
const failedRequests = new Counter("failed_requests_total");
|
||||
|
||||
// HTTP error tracking
|
||||
const httpErrors = new Counter("http_errors_by_status");
|
||||
|
||||
// Enhanced error logging function
|
||||
function logHttpError(response, endpoint, method = "GET") {
|
||||
if (response.status !== 200) {
|
||||
console.error(
|
||||
`❌ VU ${__VU} ${method} ${endpoint} failed: status=${response.status}, error=${response.error || "unknown"}, body=${response.body ? response.body.substring(0, 200) : "empty"}`,
|
||||
);
|
||||
httpErrors.add(1, {
|
||||
status: response.status,
|
||||
endpoint: endpoint,
|
||||
method: method,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Test configuration
|
||||
const VUS = parseInt(__ENV.VUS) || 10;
|
||||
const DURATION = __ENV.DURATION || "2m";
|
||||
const RAMP_UP = __ENV.RAMP_UP || "30s";
|
||||
const RAMP_DOWN = __ENV.RAMP_DOWN || "30s";
|
||||
|
||||
// Performance thresholds for marketplace browsing
|
||||
const REQUEST_TIMEOUT = 60000; // 60s per request timeout
|
||||
const THRESHOLD_P95 = parseInt(__ENV.THRESHOLD_P95) || 5000; // 5s for public endpoints
|
||||
const THRESHOLD_P99 = parseInt(__ENV.THRESHOLD_P99) || 10000; // 10s for public endpoints
|
||||
const THRESHOLD_ERROR_RATE = parseFloat(__ENV.THRESHOLD_ERROR_RATE) || 0.05; // 5% error rate
|
||||
const THRESHOLD_CHECK_RATE = parseFloat(__ENV.THRESHOLD_CHECK_RATE) || 0.95; // 95% success rate
|
||||
|
||||
export const options = {
|
||||
stages: [
|
||||
{ duration: RAMP_UP, target: VUS },
|
||||
{ duration: DURATION, target: VUS },
|
||||
{ duration: RAMP_DOWN, target: 0 },
|
||||
],
|
||||
// Thresholds disabled to collect all results regardless of performance
|
||||
// thresholds: {
|
||||
// http_req_duration: [
|
||||
// { threshold: `p(95)<${THRESHOLD_P95}`, abortOnFail: false },
|
||||
// { threshold: `p(99)<${THRESHOLD_P99}`, abortOnFail: false },
|
||||
// ],
|
||||
// http_req_failed: [{ threshold: `rate<${THRESHOLD_ERROR_RATE}`, abortOnFail: false }],
|
||||
// checks: [{ threshold: `rate>${THRESHOLD_CHECK_RATE}`, abortOnFail: false }],
|
||||
// },
|
||||
tags: {
|
||||
test_type: "marketplace_public_access",
|
||||
environment: __ENV.K6_ENVIRONMENT || "DEV",
|
||||
},
|
||||
};
|
||||
|
||||
export default function () {
|
||||
console.log(`🛒 VU ${__VU} starting marketplace browsing journey...`);
|
||||
|
||||
// Simulate realistic user marketplace browsing journey
|
||||
marketplaceBrowsingJourney();
|
||||
}
|
||||
|
||||
function marketplaceBrowsingJourney() {
|
||||
const journeyStart = Date.now();
|
||||
|
||||
// Step 1: Browse marketplace homepage - get featured agents
|
||||
console.log(`🏪 VU ${__VU} browsing marketplace homepage...`);
|
||||
const featuredAgentsResponse = http.get(
|
||||
`${BASE_URL}/api/store/agents?featured=true&page=1&page_size=10`,
|
||||
);
|
||||
logHttpError(
|
||||
featuredAgentsResponse,
|
||||
"/api/store/agents?featured=true",
|
||||
"GET",
|
||||
);
|
||||
|
||||
marketplaceRequests.add(1);
|
||||
const featuredSuccess = check(featuredAgentsResponse, {
|
||||
"Featured agents endpoint returns 200": (r) => r.status === 200,
|
||||
"Featured agents response has data": (r) => {
|
||||
try {
|
||||
const json = r.json();
|
||||
return json && json.agents && Array.isArray(json.agents);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Featured agents responds within 60s": (r) =>
|
||||
r.timings.duration < REQUEST_TIMEOUT,
|
||||
});
|
||||
|
||||
if (featuredSuccess) {
|
||||
successfulRequests.add(1);
|
||||
} else {
|
||||
failedRequests.add(1);
|
||||
}
|
||||
|
||||
// Step 2: Browse all agents with pagination
|
||||
console.log(`📋 VU ${__VU} browsing all agents...`);
|
||||
const allAgentsResponse = http.get(
|
||||
`${BASE_URL}/api/store/agents?page=1&page_size=20`,
|
||||
);
|
||||
logHttpError(allAgentsResponse, "/api/store/agents", "GET");
|
||||
|
||||
marketplaceRequests.add(1);
|
||||
const allAgentsSuccess = check(allAgentsResponse, {
|
||||
"All agents endpoint returns 200": (r) => r.status === 200,
|
||||
"All agents response has data": (r) => {
|
||||
try {
|
||||
const json = r.json();
|
||||
return (
|
||||
json &&
|
||||
json.agents &&
|
||||
Array.isArray(json.agents) &&
|
||||
json.agents.length > 0
|
||||
);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"All agents responds within 60s": (r) =>
|
||||
r.timings.duration < REQUEST_TIMEOUT,
|
||||
});
|
||||
|
||||
if (allAgentsSuccess) {
|
||||
successfulRequests.add(1);
|
||||
} else {
|
||||
failedRequests.add(1);
|
||||
}
|
||||
|
||||
// Step 3: Search for specific agents
|
||||
const searchQueries = [
|
||||
"automation",
|
||||
"social media",
|
||||
"data analysis",
|
||||
"productivity",
|
||||
];
|
||||
const randomQuery =
|
||||
searchQueries[Math.floor(Math.random() * searchQueries.length)];
|
||||
|
||||
console.log(`🔍 VU ${__VU} searching for "${randomQuery}" agents...`);
|
||||
const searchResponse = http.get(
|
||||
`${BASE_URL}/api/store/agents?search_query=${encodeURIComponent(randomQuery)}&page=1&page_size=10`,
|
||||
);
|
||||
logHttpError(searchResponse, "/api/store/agents (search)", "GET");
|
||||
|
||||
marketplaceRequests.add(1);
|
||||
const searchSuccess = check(searchResponse, {
|
||||
"Search agents endpoint returns 200": (r) => r.status === 200,
|
||||
"Search agents response has data": (r) => {
|
||||
try {
|
||||
const json = r.json();
|
||||
return json && json.agents && Array.isArray(json.agents);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Search agents responds within 60s": (r) =>
|
||||
r.timings.duration < REQUEST_TIMEOUT,
|
||||
});
|
||||
|
||||
if (searchSuccess) {
|
||||
successfulRequests.add(1);
|
||||
} else {
|
||||
failedRequests.add(1);
|
||||
}
|
||||
|
||||
// Step 4: Browse agents by category
|
||||
const categories = ["AI", "PRODUCTIVITY", "COMMUNICATION", "DATA", "SOCIAL"];
|
||||
const randomCategory =
|
||||
categories[Math.floor(Math.random() * categories.length)];
|
||||
|
||||
console.log(`📂 VU ${__VU} browsing "${randomCategory}" category...`);
|
||||
const categoryResponse = http.get(
|
||||
`${BASE_URL}/api/store/agents?category=${randomCategory}&page=1&page_size=15`,
|
||||
);
|
||||
logHttpError(categoryResponse, "/api/store/agents (category)", "GET");
|
||||
|
||||
marketplaceRequests.add(1);
|
||||
const categorySuccess = check(categoryResponse, {
|
||||
"Category agents endpoint returns 200": (r) => r.status === 200,
|
||||
"Category agents response has data": (r) => {
|
||||
try {
|
||||
const json = r.json();
|
||||
return json && json.agents && Array.isArray(json.agents);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Category agents responds within 60s": (r) =>
|
||||
r.timings.duration < REQUEST_TIMEOUT,
|
||||
});
|
||||
|
||||
if (categorySuccess) {
|
||||
successfulRequests.add(1);
|
||||
} else {
|
||||
failedRequests.add(1);
|
||||
}
|
||||
|
||||
// Step 5: Get specific agent details (simulate clicking on an agent)
|
||||
if (allAgentsResponse.status === 200) {
|
||||
try {
|
||||
const allAgentsJson = allAgentsResponse.json();
|
||||
if (allAgentsJson?.agents && allAgentsJson.agents.length > 0) {
|
||||
const randomAgent =
|
||||
allAgentsJson.agents[
|
||||
Math.floor(Math.random() * allAgentsJson.agents.length)
|
||||
];
|
||||
|
||||
if (randomAgent?.creator_username && randomAgent?.slug) {
|
||||
console.log(
|
||||
`📄 VU ${__VU} viewing agent details for "${randomAgent.slug}"...`,
|
||||
);
|
||||
const agentDetailsResponse = http.get(
|
||||
`${BASE_URL}/api/store/agents/${encodeURIComponent(randomAgent.creator_username)}/${encodeURIComponent(randomAgent.slug)}`,
|
||||
);
|
||||
logHttpError(
|
||||
agentDetailsResponse,
|
||||
"/api/store/agents/{creator}/{slug}",
|
||||
"GET",
|
||||
);
|
||||
|
||||
marketplaceRequests.add(1);
|
||||
const agentDetailsSuccess = check(agentDetailsResponse, {
|
||||
"Agent details endpoint returns 200": (r) => r.status === 200,
|
||||
"Agent details response has data": (r) => {
|
||||
try {
|
||||
const json = r.json();
|
||||
return json && json.id && json.name && json.description;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Agent details responds within 60s": (r) =>
|
||||
r.timings.duration < REQUEST_TIMEOUT,
|
||||
});
|
||||
|
||||
if (agentDetailsSuccess) {
|
||||
successfulRequests.add(1);
|
||||
} else {
|
||||
failedRequests.add(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.warn(
|
||||
`⚠️ VU ${__VU} failed to parse agents data for details lookup: ${e}`,
|
||||
);
|
||||
failedRequests.add(1);
|
||||
}
|
||||
}
|
||||
|
||||
// Step 6: Browse creators
|
||||
console.log(`👥 VU ${__VU} browsing creators...`);
|
||||
const creatorsResponse = http.get(
|
||||
`${BASE_URL}/api/store/creators?page=1&page_size=20`,
|
||||
);
|
||||
logHttpError(creatorsResponse, "/api/store/creators", "GET");
|
||||
|
||||
marketplaceRequests.add(1);
|
||||
const creatorsSuccess = check(creatorsResponse, {
|
||||
"Creators endpoint returns 200": (r) => r.status === 200,
|
||||
"Creators response has data": (r) => {
|
||||
try {
|
||||
const json = r.json();
|
||||
return json && json.creators && Array.isArray(json.creators);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Creators responds within 60s": (r) => r.timings.duration < REQUEST_TIMEOUT,
|
||||
});
|
||||
|
||||
if (creatorsSuccess) {
|
||||
successfulRequests.add(1);
|
||||
} else {
|
||||
failedRequests.add(1);
|
||||
}
|
||||
|
||||
// Step 7: Get featured creators
|
||||
console.log(`⭐ VU ${__VU} browsing featured creators...`);
|
||||
const featuredCreatorsResponse = http.get(
|
||||
`${BASE_URL}/api/store/creators?featured=true&page=1&page_size=10`,
|
||||
);
|
||||
logHttpError(
|
||||
featuredCreatorsResponse,
|
||||
"/api/store/creators?featured=true",
|
||||
"GET",
|
||||
);
|
||||
|
||||
marketplaceRequests.add(1);
|
||||
const featuredCreatorsSuccess = check(featuredCreatorsResponse, {
|
||||
"Featured creators endpoint returns 200": (r) => r.status === 200,
|
||||
"Featured creators response has data": (r) => {
|
||||
try {
|
||||
const json = r.json();
|
||||
return json && json.creators && Array.isArray(json.creators);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Featured creators responds within 60s": (r) =>
|
||||
r.timings.duration < REQUEST_TIMEOUT,
|
||||
});
|
||||
|
||||
if (featuredCreatorsSuccess) {
|
||||
successfulRequests.add(1);
|
||||
} else {
|
||||
failedRequests.add(1);
|
||||
}
|
||||
|
||||
// Step 8: Get specific creator details (simulate clicking on a creator)
|
||||
if (creatorsResponse.status === 200) {
|
||||
try {
|
||||
const creatorsJson = creatorsResponse.json();
|
||||
if (creatorsJson?.creators && creatorsJson.creators.length > 0) {
|
||||
const randomCreator =
|
||||
creatorsJson.creators[
|
||||
Math.floor(Math.random() * creatorsJson.creators.length)
|
||||
];
|
||||
|
||||
if (randomCreator?.username) {
|
||||
console.log(
|
||||
`👤 VU ${__VU} viewing creator details for "${randomCreator.username}"...`,
|
||||
);
|
||||
const creatorDetailsResponse = http.get(
|
||||
`${BASE_URL}/api/store/creator/${encodeURIComponent(randomCreator.username)}`,
|
||||
);
|
||||
logHttpError(
|
||||
creatorDetailsResponse,
|
||||
"/api/store/creator/{username}",
|
||||
"GET",
|
||||
);
|
||||
|
||||
marketplaceRequests.add(1);
|
||||
const creatorDetailsSuccess = check(creatorDetailsResponse, {
|
||||
"Creator details endpoint returns 200": (r) => r.status === 200,
|
||||
"Creator details response has data": (r) => {
|
||||
try {
|
||||
const json = r.json();
|
||||
return json && json.username && json.description !== undefined;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
"Creator details responds within 60s": (r) =>
|
||||
r.timings.duration < REQUEST_TIMEOUT,
|
||||
});
|
||||
|
||||
if (creatorDetailsSuccess) {
|
||||
successfulRequests.add(1);
|
||||
} else {
|
||||
failedRequests.add(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.warn(
|
||||
`⚠️ VU ${__VU} failed to parse creators data for details lookup: ${e}`,
|
||||
);
|
||||
failedRequests.add(1);
|
||||
}
|
||||
}
|
||||
|
||||
const journeyDuration = Date.now() - journeyStart;
|
||||
console.log(
|
||||
`✅ VU ${__VU} completed marketplace browsing journey in ${journeyDuration}ms`,
|
||||
);
|
||||
}
|
||||
|
||||
export function handleSummary(data) {
|
||||
const summary = {
|
||||
test_type: "Marketplace Public Access Load Test",
|
||||
environment: __ENV.K6_ENVIRONMENT || "DEV",
|
||||
configuration: {
|
||||
virtual_users: VUS,
|
||||
duration: DURATION,
|
||||
ramp_up: RAMP_UP,
|
||||
ramp_down: RAMP_DOWN,
|
||||
},
|
||||
performance_metrics: {
|
||||
total_requests: data.metrics.http_reqs?.count || 0,
|
||||
failed_requests: data.metrics.http_req_failed?.values?.passes || 0,
|
||||
avg_response_time: data.metrics.http_req_duration?.values?.avg || 0,
|
||||
p95_response_time: data.metrics.http_req_duration?.values?.p95 || 0,
|
||||
p99_response_time: data.metrics.http_req_duration?.values?.p99 || 0,
|
||||
},
|
||||
custom_metrics: {
|
||||
marketplace_requests:
|
||||
data.metrics.marketplace_requests_total?.values?.count || 0,
|
||||
successful_requests:
|
||||
data.metrics.successful_requests_total?.values?.count || 0,
|
||||
failed_requests: data.metrics.failed_requests_total?.values?.count || 0,
|
||||
},
|
||||
thresholds_met: {
|
||||
p95_threshold:
|
||||
(data.metrics.http_req_duration?.values?.p95 || 0) < THRESHOLD_P95,
|
||||
p99_threshold:
|
||||
(data.metrics.http_req_duration?.values?.p99 || 0) < THRESHOLD_P99,
|
||||
error_rate_threshold:
|
||||
(data.metrics.http_req_failed?.values?.rate || 0) <
|
||||
THRESHOLD_ERROR_RATE,
|
||||
check_rate_threshold:
|
||||
(data.metrics.checks?.values?.rate || 0) > THRESHOLD_CHECK_RATE,
|
||||
},
|
||||
user_journey_coverage: [
|
||||
"Browse featured agents",
|
||||
"Browse all agents with pagination",
|
||||
"Search agents by keywords",
|
||||
"Filter agents by category",
|
||||
"View specific agent details",
|
||||
"Browse creators directory",
|
||||
"View featured creators",
|
||||
"View specific creator details",
|
||||
],
|
||||
};
|
||||
|
||||
console.log("\n📊 MARKETPLACE PUBLIC ACCESS TEST SUMMARY");
|
||||
console.log("==========================================");
|
||||
console.log(`Environment: ${summary.environment}`);
|
||||
console.log(`Virtual Users: ${summary.configuration.virtual_users}`);
|
||||
console.log(`Duration: ${summary.configuration.duration}`);
|
||||
console.log(`Total Requests: ${summary.performance_metrics.total_requests}`);
|
||||
console.log(
|
||||
`Successful Requests: ${summary.custom_metrics.successful_requests}`,
|
||||
);
|
||||
console.log(`Failed Requests: ${summary.custom_metrics.failed_requests}`);
|
||||
console.log(
|
||||
`Average Response Time: ${Math.round(summary.performance_metrics.avg_response_time)}ms`,
|
||||
);
|
||||
console.log(
|
||||
`95th Percentile: ${Math.round(summary.performance_metrics.p95_response_time)}ms`,
|
||||
);
|
||||
console.log(
|
||||
`99th Percentile: ${Math.round(summary.performance_metrics.p99_response_time)}ms`,
|
||||
);
|
||||
|
||||
console.log("\n🎯 Threshold Status:");
|
||||
console.log(
|
||||
`P95 < ${THRESHOLD_P95}ms: ${summary.thresholds_met.p95_threshold ? "✅" : "❌"}`,
|
||||
);
|
||||
console.log(
|
||||
`P99 < ${THRESHOLD_P99}ms: ${summary.thresholds_met.p99_threshold ? "✅" : "❌"}`,
|
||||
);
|
||||
console.log(
|
||||
`Error Rate < ${THRESHOLD_ERROR_RATE * 100}%: ${summary.thresholds_met.error_rate_threshold ? "✅" : "❌"}`,
|
||||
);
|
||||
console.log(
|
||||
`Check Rate > ${THRESHOLD_CHECK_RATE * 100}%: ${summary.thresholds_met.check_rate_threshold ? "✅" : "❌"}`,
|
||||
);
|
||||
|
||||
return {
|
||||
stdout: JSON.stringify(summary, null, 2),
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- You are about to drop the column `notificationDot` on the `UserOnboarding` table. All the data in the column will be lost.
|
||||
|
||||
*/
|
||||
-- AlterEnum
|
||||
-- This migration adds more than one value to an enum.
|
||||
-- With PostgreSQL versions 11 and earlier, this is not possible
|
||||
-- in a single migration. This can be worked around by creating
|
||||
-- multiple migrations, each migration adding only one value to
|
||||
-- the enum.
|
||||
|
||||
|
||||
ALTER TYPE "OnboardingStep" ADD VALUE 'RE_RUN_AGENT';
|
||||
ALTER TYPE "OnboardingStep" ADD VALUE 'SCHEDULE_AGENT';
|
||||
ALTER TYPE "OnboardingStep" ADD VALUE 'RUN_3_DAYS';
|
||||
ALTER TYPE "OnboardingStep" ADD VALUE 'TRIGGER_WEBHOOK';
|
||||
ALTER TYPE "OnboardingStep" ADD VALUE 'RUN_14_DAYS';
|
||||
ALTER TYPE "OnboardingStep" ADD VALUE 'RUN_AGENTS_100';
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "UserOnboarding" DROP COLUMN "notificationDot",
|
||||
ADD COLUMN "consecutiveRunDays" INTEGER NOT NULL DEFAULT 0,
|
||||
ADD COLUMN "lastRunAt" TIMESTAMP(3),
|
||||
ADD COLUMN "walletShown" BOOLEAN NOT NULL DEFAULT false;
|
||||
@@ -0,0 +1,11 @@
|
||||
-- DropIndex
|
||||
DROP INDEX "AgentGraph_userId_isActive_idx";
|
||||
|
||||
-- DropIndex
|
||||
DROP INDEX "AgentGraphExecution_userId_idx";
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "AgentGraph_userId_isActive_id_version_idx" ON "AgentGraph"("userId", "isActive", "id", "version");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "AgentGraphExecution_userId_isDeleted_createdAt_idx" ON "AgentGraphExecution"("userId", "isDeleted", "createdAt");
|
||||
95
autogpt_platform/backend/poetry.lock
generated
95
autogpt_platform/backend/poetry.lock
generated
@@ -3451,6 +3451,99 @@ files = [
|
||||
importlib-metadata = ">=6.0,<8.8.0"
|
||||
typing-extensions = ">=4.5.0"
|
||||
|
||||
[[package]]
|
||||
name = "orjson"
|
||||
version = "3.11.3"
|
||||
description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "orjson-3.11.3-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:29cb1f1b008d936803e2da3d7cba726fc47232c45df531b29edf0b232dd737e7"},
|
||||
{file = "orjson-3.11.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:97dceed87ed9139884a55db8722428e27bd8452817fbf1869c58b49fecab1120"},
|
||||
{file = "orjson-3.11.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:58533f9e8266cb0ac298e259ed7b4d42ed3fa0b78ce76860626164de49e0d467"},
|
||||
{file = "orjson-3.11.3-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0c212cfdd90512fe722fa9bd620de4d46cda691415be86b2e02243242ae81873"},
|
||||
{file = "orjson-3.11.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5ff835b5d3e67d9207343effb03760c00335f8b5285bfceefd4dc967b0e48f6a"},
|
||||
{file = "orjson-3.11.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f5aa4682912a450c2db89cbd92d356fef47e115dffba07992555542f344d301b"},
|
||||
{file = "orjson-3.11.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7d18dd34ea2e860553a579df02041845dee0af8985dff7f8661306f95504ddf"},
|
||||
{file = "orjson-3.11.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d8b11701bc43be92ea42bd454910437b355dfb63696c06fe953ffb40b5f763b4"},
|
||||
{file = "orjson-3.11.3-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:90368277087d4af32d38bd55f9da2ff466d25325bf6167c8f382d8ee40cb2bbc"},
|
||||
{file = "orjson-3.11.3-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:fd7ff459fb393358d3a155d25b275c60b07a2c83dcd7ea962b1923f5a1134569"},
|
||||
{file = "orjson-3.11.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f8d902867b699bcd09c176a280b1acdab57f924489033e53d0afe79817da37e6"},
|
||||
{file = "orjson-3.11.3-cp310-cp310-win32.whl", hash = "sha256:bb93562146120bb51e6b154962d3dadc678ed0fce96513fa6bc06599bb6f6edc"},
|
||||
{file = "orjson-3.11.3-cp310-cp310-win_amd64.whl", hash = "sha256:976c6f1975032cc327161c65d4194c549f2589d88b105a5e3499429a54479770"},
|
||||
{file = "orjson-3.11.3-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:9d2ae0cc6aeb669633e0124531f342a17d8e97ea999e42f12a5ad4adaa304c5f"},
|
||||
{file = "orjson-3.11.3-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:ba21dbb2493e9c653eaffdc38819b004b7b1b246fb77bfc93dc016fe664eac91"},
|
||||
{file = "orjson-3.11.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00f1a271e56d511d1569937c0447d7dce5a99a33ea0dec76673706360a051904"},
|
||||
{file = "orjson-3.11.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b67e71e47caa6680d1b6f075a396d04fa6ca8ca09aafb428731da9b3ea32a5a6"},
|
||||
{file = "orjson-3.11.3-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d7d012ebddffcce8c85734a6d9e5f08180cd3857c5f5a3ac70185b43775d043d"},
|
||||
{file = "orjson-3.11.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dd759f75d6b8d1b62012b7f5ef9461d03c804f94d539a5515b454ba3a6588038"},
|
||||
{file = "orjson-3.11.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6890ace0809627b0dff19cfad92d69d0fa3f089d3e359a2a532507bb6ba34efb"},
|
||||
{file = "orjson-3.11.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9d4a5e041ae435b815e568537755773d05dac031fee6a57b4ba70897a44d9d2"},
|
||||
{file = "orjson-3.11.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2d68bf97a771836687107abfca089743885fb664b90138d8761cce61d5625d55"},
|
||||
{file = "orjson-3.11.3-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:bfc27516ec46f4520b18ef645864cee168d2a027dbf32c5537cb1f3e3c22dac1"},
|
||||
{file = "orjson-3.11.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:f66b001332a017d7945e177e282a40b6997056394e3ed7ddb41fb1813b83e824"},
|
||||
{file = "orjson-3.11.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:212e67806525d2561efbfe9e799633b17eb668b8964abed6b5319b2f1cfbae1f"},
|
||||
{file = "orjson-3.11.3-cp311-cp311-win32.whl", hash = "sha256:6e8e0c3b85575a32f2ffa59de455f85ce002b8bdc0662d6b9c2ed6d80ab5d204"},
|
||||
{file = "orjson-3.11.3-cp311-cp311-win_amd64.whl", hash = "sha256:6be2f1b5d3dc99a5ce5ce162fc741c22ba9f3443d3dd586e6a1211b7bc87bc7b"},
|
||||
{file = "orjson-3.11.3-cp311-cp311-win_arm64.whl", hash = "sha256:fafb1a99d740523d964b15c8db4eabbfc86ff29f84898262bf6e3e4c9e97e43e"},
|
||||
{file = "orjson-3.11.3-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:8c752089db84333e36d754c4baf19c0e1437012242048439c7e80eb0e6426e3b"},
|
||||
{file = "orjson-3.11.3-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:9b8761b6cf04a856eb544acdd82fc594b978f12ac3602d6374a7edb9d86fd2c2"},
|
||||
{file = "orjson-3.11.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b13974dc8ac6ba22feaa867fc19135a3e01a134b4f7c9c28162fed4d615008a"},
|
||||
{file = "orjson-3.11.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f83abab5bacb76d9c821fd5c07728ff224ed0e52d7a71b7b3de822f3df04e15c"},
|
||||
{file = "orjson-3.11.3-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e6fbaf48a744b94091a56c62897b27c31ee2da93d826aa5b207131a1e13d4064"},
|
||||
{file = "orjson-3.11.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bc779b4f4bba2847d0d2940081a7b6f7b5877e05408ffbb74fa1faf4a136c424"},
|
||||
{file = "orjson-3.11.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bd4b909ce4c50faa2192da6bb684d9848d4510b736b0611b6ab4020ea6fd2d23"},
|
||||
{file = "orjson-3.11.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:524b765ad888dc5518bbce12c77c2e83dee1ed6b0992c1790cc5fb49bb4b6667"},
|
||||
{file = "orjson-3.11.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:84fd82870b97ae3cdcea9d8746e592b6d40e1e4d4527835fc520c588d2ded04f"},
|
||||
{file = "orjson-3.11.3-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:fbecb9709111be913ae6879b07bafd4b0785b44c1eb5cac8ac76da048b3885a1"},
|
||||
{file = "orjson-3.11.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9dba358d55aee552bd868de348f4736ca5a4086d9a62e2bfbbeeb5629fe8b0cc"},
|
||||
{file = "orjson-3.11.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:eabcf2e84f1d7105f84580e03012270c7e97ecb1fb1618bda395061b2a84a049"},
|
||||
{file = "orjson-3.11.3-cp312-cp312-win32.whl", hash = "sha256:3782d2c60b8116772aea8d9b7905221437fdf53e7277282e8d8b07c220f96cca"},
|
||||
{file = "orjson-3.11.3-cp312-cp312-win_amd64.whl", hash = "sha256:79b44319268af2eaa3e315b92298de9a0067ade6e6003ddaef72f8e0bedb94f1"},
|
||||
{file = "orjson-3.11.3-cp312-cp312-win_arm64.whl", hash = "sha256:0e92a4e83341ef79d835ca21b8bd13e27c859e4e9e4d7b63defc6e58462a3710"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:af40c6612fd2a4b00de648aa26d18186cd1322330bd3a3cc52f87c699e995810"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:9f1587f26c235894c09e8b5b7636a38091a9e6e7fe4531937534749c04face43"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:61dcdad16da5bb486d7227a37a2e789c429397793a6955227cedbd7252eb5a27"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:11c6d71478e2cbea0a709e8a06365fa63da81da6498a53e4c4f065881d21ae8f"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ff94112e0098470b665cb0ed06efb187154b63649403b8d5e9aedeb482b4548c"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae8b756575aaa2a855a75192f356bbda11a89169830e1439cfb1a3e1a6dde7be"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c9416cc19a349c167ef76135b2fe40d03cea93680428efee8771f3e9fb66079d"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b822caf5b9752bc6f246eb08124c3d12bf2175b66ab74bac2ef3bbf9221ce1b2"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:414f71e3bdd5573893bf5ecdf35c32b213ed20aa15536fe2f588f946c318824f"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:828e3149ad8815dc14468f36ab2a4b819237c155ee1370341b91ea4c8672d2ee"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ac9e05f25627ffc714c21f8dfe3a579445a5c392a9c8ae7ba1d0e9fb5333f56e"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e44fbe4000bd321d9f3b648ae46e0196d21577cf66ae684a96ff90b1f7c93633"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-win32.whl", hash = "sha256:2039b7847ba3eec1f5886e75e6763a16e18c68a63efc4b029ddf994821e2e66b"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-win_amd64.whl", hash = "sha256:29be5ac4164aa8bdcba5fa0700a3c9c316b411d8ed9d39ef8a882541bd452fae"},
|
||||
{file = "orjson-3.11.3-cp313-cp313-win_arm64.whl", hash = "sha256:18bd1435cb1f2857ceb59cfb7de6f92593ef7b831ccd1b9bfb28ca530e539dce"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:cf4b81227ec86935568c7edd78352a92e97af8da7bd70bdfdaa0d2e0011a1ab4"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:bc8bc85b81b6ac9fc4dae393a8c159b817f4c2c9dee5d12b773bddb3b95fc07e"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-manylinux_2_34_aarch64.whl", hash = "sha256:88dcfc514cfd1b0de038443c7b3e6a9797ffb1b3674ef1fd14f701a13397f82d"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-manylinux_2_34_x86_64.whl", hash = "sha256:d61cd543d69715d5fc0a690c7c6f8dcc307bc23abef9738957981885f5f38229"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:2b7b153ed90ababadbef5c3eb39549f9476890d339cf47af563aea7e07db2451"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:7909ae2460f5f494fecbcd10613beafe40381fd0316e35d6acb5f3a05bfda167"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:2030c01cbf77bc67bee7eef1e7e31ecf28649353987775e3583062c752da0077"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a0169ebd1cbd94b26c7a7ad282cf5c2744fce054133f959e02eb5265deae1872"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-win32.whl", hash = "sha256:0c6d7328c200c349e3a4c6d8c83e0a5ad029bdc2d417f234152bf34842d0fc8d"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-win_amd64.whl", hash = "sha256:317bbe2c069bbc757b1a2e4105b64aacd3bc78279b66a6b9e51e846e4809f804"},
|
||||
{file = "orjson-3.11.3-cp314-cp314-win_arm64.whl", hash = "sha256:e8f6a7a27d7b7bec81bd5924163e9af03d49bbb63013f107b48eb5d16db711bc"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:56afaf1e9b02302ba636151cfc49929c1bb66b98794291afd0e5f20fecaf757c"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:913f629adef31d2d350d41c051ce7e33cf0fd06a5d1cb28d49b1899b23b903aa"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e0a23b41f8f98b4e61150a03f83e4f0d566880fe53519d445a962929a4d21045"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3d721fee37380a44f9d9ce6c701b3960239f4fb3d5ceea7f31cbd43882edaa2f"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:73b92a5b69f31b1a58c0c7e31080aeaec49c6e01b9522e71ff38d08f15aa56de"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d2489b241c19582b3f1430cc5d732caefc1aaf378d97e7fb95b9e56bed11725f"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5189a5dab8b0312eadaf9d58d3049b6a52c454256493a557405e77a3d67ab7f"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:9d8787bdfbb65a85ea76d0e96a3b1bed7bf0fbcb16d40408dc1172ad784a49d2"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:8e531abd745f51f8035e207e75e049553a86823d189a51809c078412cefb399a"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:8ab962931015f170b97a3dd7bd933399c1bae8ed8ad0fb2a7151a5654b6941c7"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:124d5ba71fee9c9902c4a7baa9425e663f7f0aecf73d31d54fe3dd357d62c1a7"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-win32.whl", hash = "sha256:22724d80ee5a815a44fc76274bb7ba2e7464f5564aacb6ecddaa9970a83e3225"},
|
||||
{file = "orjson-3.11.3-cp39-cp39-win_amd64.whl", hash = "sha256:215c595c792a87d4407cb72dd5e0f6ee8e694ceeb7f9102b533c5a9bf2a916bb"},
|
||||
{file = "orjson-3.11.3.tar.gz", hash = "sha256:1c0603b1d2ffcd43a411d64797a19556ef76958aef1c182f22dc30860152a98a"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
version = "24.2"
|
||||
@@ -7159,4 +7252,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "2c7e9370f500039b99868376021627c5a120e0ee31c5c5e6de39db2c3d82f414"
|
||||
content-hash = "b2363edeebb91f410039c8d4b563f683c1edb0cf4bda4f3e6c287040e93639bc"
|
||||
|
||||
@@ -38,6 +38,7 @@ mem0ai = "^0.1.115"
|
||||
moviepy = "^2.1.2"
|
||||
ollama = "^0.5.1"
|
||||
openai = "^1.97.1"
|
||||
orjson = "^3.10.0"
|
||||
pika = "^1.3.2"
|
||||
pinecone = "^7.3.0"
|
||||
poetry = "2.1.1" # CHECK DEPENDABOT SUPPORT BEFORE UPGRADING
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Run test data creation and update scripts in sequence.
|
||||
|
||||
Usage:
|
||||
poetry run python run_test_data.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def run_command(cmd: list[str], cwd: Path | None = None) -> bool:
|
||||
"""Run a command and return True if successful."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd, check=True, capture_output=True, text=True, cwd=cwd
|
||||
)
|
||||
if result.stdout:
|
||||
print(result.stdout)
|
||||
return True
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error running command: {' '.join(cmd)}")
|
||||
print(f"Error: {e.stderr}")
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function to run test data scripts."""
|
||||
print("=" * 60)
|
||||
print("Running Test Data Scripts for AutoGPT Platform")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
# Get the backend directory
|
||||
backend_dir = Path(__file__).parent
|
||||
test_dir = backend_dir / "test"
|
||||
|
||||
# Check if we're in the right directory
|
||||
if not (backend_dir / "pyproject.toml").exists():
|
||||
print("ERROR: This script must be run from the backend directory")
|
||||
sys.exit(1)
|
||||
|
||||
print("1. Checking database connection...")
|
||||
print("-" * 40)
|
||||
|
||||
# Import here to ensure proper environment setup
|
||||
try:
|
||||
from prisma import Prisma
|
||||
|
||||
db = Prisma()
|
||||
await db.connect()
|
||||
print("✓ Database connection successful")
|
||||
await db.disconnect()
|
||||
except Exception as e:
|
||||
print(f"✗ Database connection failed: {e}")
|
||||
print("\nPlease ensure:")
|
||||
print("1. The database services are running (docker compose up -d)")
|
||||
print("2. The DATABASE_URL in .env is correct")
|
||||
print("3. Migrations have been run (poetry run prisma migrate deploy)")
|
||||
sys.exit(1)
|
||||
|
||||
print()
|
||||
print("2. Running test data creator...")
|
||||
print("-" * 40)
|
||||
|
||||
# Run test_data_creator.py
|
||||
if run_command(["poetry", "run", "python", "test_data_creator.py"], cwd=test_dir):
|
||||
print()
|
||||
print("✅ Test data created successfully!")
|
||||
|
||||
print()
|
||||
print("3. Running test data updater...")
|
||||
print("-" * 40)
|
||||
|
||||
# Run test_data_updater.py
|
||||
if run_command(
|
||||
["poetry", "run", "python", "test_data_updater.py"], cwd=test_dir
|
||||
):
|
||||
print()
|
||||
print("✅ Test data updated successfully!")
|
||||
else:
|
||||
print()
|
||||
print("❌ Test data updater failed!")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print()
|
||||
print("❌ Test data creator failed!")
|
||||
sys.exit(1)
|
||||
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("Test data setup completed successfully!")
|
||||
print("=" * 60)
|
||||
print()
|
||||
print("The materialized views have been populated with test data:")
|
||||
print("- mv_agent_run_counts: Agent execution statistics")
|
||||
print("- mv_review_stats: Store listing review statistics")
|
||||
print()
|
||||
print("You can now:")
|
||||
print("1. Run tests: poetry run test")
|
||||
print("2. Start the backend: poetry run serve")
|
||||
print("3. View data in the database")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -68,15 +68,23 @@ enum OnboardingStep {
|
||||
AGENT_NEW_RUN
|
||||
AGENT_INPUT
|
||||
CONGRATS
|
||||
// First Wins
|
||||
GET_RESULTS
|
||||
RUN_AGENTS
|
||||
// Marketplace
|
||||
MARKETPLACE_VISIT
|
||||
MARKETPLACE_ADD_AGENT
|
||||
MARKETPLACE_RUN_AGENT
|
||||
// Builder
|
||||
BUILDER_OPEN
|
||||
BUILDER_SAVE_AGENT
|
||||
// Consistency Challenge
|
||||
RE_RUN_AGENT
|
||||
SCHEDULE_AGENT
|
||||
RUN_AGENTS
|
||||
RUN_3_DAYS
|
||||
// The Pro Playground
|
||||
TRIGGER_WEBHOOK
|
||||
RUN_14_DAYS
|
||||
RUN_AGENTS_100
|
||||
// No longer rewarded but exist for analytical purposes
|
||||
BUILDER_OPEN
|
||||
BUILDER_RUN_AGENT
|
||||
}
|
||||
|
||||
@@ -86,7 +94,7 @@ model UserOnboarding {
|
||||
updatedAt DateTime? @updatedAt
|
||||
|
||||
completedSteps OnboardingStep[] @default([])
|
||||
notificationDot Boolean @default(true)
|
||||
walletShown Boolean @default(false)
|
||||
notified OnboardingStep[] @default([])
|
||||
rewardedFor OnboardingStep[] @default([])
|
||||
usageReason String?
|
||||
@@ -96,6 +104,8 @@ model UserOnboarding {
|
||||
agentInput Json?
|
||||
onboardingAgentExecutionId String?
|
||||
agentRuns Int @default(0)
|
||||
lastRunAt DateTime?
|
||||
consecutiveRunDays Int @default(0)
|
||||
|
||||
userId String @unique
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
@@ -135,7 +145,7 @@ model AgentGraph {
|
||||
StoreListingVersions StoreListingVersion[]
|
||||
|
||||
@@id(name: "graphVersionId", [id, version])
|
||||
@@index([userId, isActive])
|
||||
@@index([userId, isActive, id, version])
|
||||
@@index([forkedFromId, forkedFromVersion])
|
||||
}
|
||||
|
||||
@@ -377,7 +387,7 @@ model AgentGraphExecution {
|
||||
sharedAt DateTime?
|
||||
|
||||
@@index([agentGraphId, agentGraphVersion])
|
||||
@@index([userId])
|
||||
@@index([userId, isDeleted, createdAt])
|
||||
@@index([createdAt])
|
||||
@@index([agentPresetId])
|
||||
@@index([shareToken])
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user